diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 16:04:21 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 16:04:21 +0000 |
commit | 8a754e0858d922e955e71b253c139e071ecec432 (patch) | |
tree | 527d16e74bfd1840c85efd675fdecad056c54107 /lib/ansible/parsing | |
parent | Initial commit. (diff) | |
download | ansible-core-upstream/2.14.3.tar.xz ansible-core-upstream/2.14.3.zip |
Adding upstream version 2.14.3.upstream/2.14.3upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'lib/ansible/parsing')
-rw-r--r-- | lib/ansible/parsing/__init__.py | 20 | ||||
-rw-r--r-- | lib/ansible/parsing/ajson.py | 42 | ||||
-rw-r--r-- | lib/ansible/parsing/dataloader.py | 468 | ||||
-rw-r--r-- | lib/ansible/parsing/mod_args.py | 345 | ||||
-rw-r--r-- | lib/ansible/parsing/plugin_docs.py | 227 | ||||
-rw-r--r-- | lib/ansible/parsing/quoting.py | 31 | ||||
-rw-r--r-- | lib/ansible/parsing/splitter.py | 286 | ||||
-rw-r--r-- | lib/ansible/parsing/utils/__init__.py | 20 | ||||
-rw-r--r-- | lib/ansible/parsing/utils/addresses.py | 216 | ||||
-rw-r--r-- | lib/ansible/parsing/utils/jsonify.py | 38 | ||||
-rw-r--r-- | lib/ansible/parsing/utils/yaml.py | 84 | ||||
-rw-r--r-- | lib/ansible/parsing/vault/__init__.py | 1289 | ||||
-rw-r--r-- | lib/ansible/parsing/yaml/__init__.py | 20 | ||||
-rw-r--r-- | lib/ansible/parsing/yaml/constructor.py | 178 | ||||
-rw-r--r-- | lib/ansible/parsing/yaml/dumper.py | 122 | ||||
-rw-r--r-- | lib/ansible/parsing/yaml/loader.py | 45 | ||||
-rw-r--r-- | lib/ansible/parsing/yaml/objects.py | 365 |
17 files changed, 3796 insertions, 0 deletions
diff --git a/lib/ansible/parsing/__init__.py b/lib/ansible/parsing/__init__.py new file mode 100644 index 0000000..28634b1 --- /dev/null +++ b/lib/ansible/parsing/__init__.py @@ -0,0 +1,20 @@ +# (c) 2015, Toshio Kuratomi <tkuratomi@ansible.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type diff --git a/lib/ansible/parsing/ajson.py b/lib/ansible/parsing/ajson.py new file mode 100644 index 0000000..8049755 --- /dev/null +++ b/lib/ansible/parsing/ajson.py @@ -0,0 +1,42 @@ +# Copyright: (c) 2018, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import json + +# Imported for backwards compat +from ansible.module_utils.common.json import AnsibleJSONEncoder + +from ansible.parsing.vault import VaultLib +from ansible.parsing.yaml.objects import AnsibleVaultEncryptedUnicode +from ansible.utils.unsafe_proxy import wrap_var + + +class AnsibleJSONDecoder(json.JSONDecoder): + + _vaults = {} # type: dict[str, VaultLib] + + def __init__(self, *args, **kwargs): + kwargs['object_hook'] = self.object_hook + super(AnsibleJSONDecoder, self).__init__(*args, **kwargs) + + 
@classmethod + def set_secrets(cls, secrets): + cls._vaults['default'] = VaultLib(secrets=secrets) + + def object_hook(self, pairs): + for key in pairs: + value = pairs[key] + + if key == '__ansible_vault': + value = AnsibleVaultEncryptedUnicode(value) + if self._vaults: + value.vault = self._vaults['default'] + return value + elif key == '__ansible_unsafe': + return wrap_var(value) + + return pairs diff --git a/lib/ansible/parsing/dataloader.py b/lib/ansible/parsing/dataloader.py new file mode 100644 index 0000000..cbba966 --- /dev/null +++ b/lib/ansible/parsing/dataloader.py @@ -0,0 +1,468 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import copy +import os +import os.path +import re +import tempfile + +from ansible import constants as C +from ansible.errors import AnsibleFileNotFound, AnsibleParserError +from ansible.module_utils.basic import is_executable +from ansible.module_utils.six import binary_type, text_type +from ansible.module_utils._text import to_bytes, to_native, to_text +from ansible.parsing.quoting import unquote +from ansible.parsing.utils.yaml import from_yaml +from ansible.parsing.vault import VaultLib, b_HEADER, is_encrypted, is_encrypted_file, parse_vaulttext_envelope +from ansible.utils.path import unfrackpath +from ansible.utils.display import Display + +display = Display() + + +# Tries to determine if a path is inside a role, last dir must be 'tasks' +# this is not perfect but people should really avoid 'tasks' dirs outside roles when using Ansible. 
+RE_TASKS = re.compile(u'(?:^|%s)+tasks%s?$' % (os.path.sep, os.path.sep)) + + +class DataLoader: + + ''' + The DataLoader class is used to load and parse YAML or JSON content, + either from a given file name or from a string that was previously + read in through other means. A Vault password can be specified, and + any vault-encrypted files will be decrypted. + + Data read from files will also be cached, so the file will never be + read from disk more than once. + + Usage: + + dl = DataLoader() + # optionally: dl.set_vault_password('foo') + ds = dl.load('...') + ds = dl.load_from_file('/path/to/file') + ''' + + def __init__(self): + + self._basedir = '.' + + # NOTE: not effective with forks as the main copy does not get updated. + # avoids rereading files + self._FILE_CACHE = dict() + + # NOTE: not thread safe, also issues with forks not returning data to main proc + # so they need to be cleaned independently. See WorkerProcess for example. + # used to keep track of temp files for cleaning + self._tempfiles = set() + + # initialize the vault stuff with an empty password + # TODO: replace with a ref to something that can get the password + # a creds/auth provider + # self.set_vault_password(None) + self._vaults = {} + self._vault = VaultLib() + self.set_vault_secrets(None) + + # TODO: since we can query vault_secrets late, we could provide this to DataLoader init + def set_vault_secrets(self, vault_secrets): + self._vault.secrets = vault_secrets + + def load(self, data, file_name='<string>', show_content=True, json_only=False): + '''Backwards compat for now''' + return from_yaml(data, file_name, show_content, self._vault.secrets, json_only=json_only) + + def load_from_file(self, file_name, cache=True, unsafe=False, json_only=False): + ''' Loads data from a file, which can contain either JSON or YAML. 
''' + + file_name = self.path_dwim(file_name) + display.debug("Loading data from %s" % file_name) + + # if the file has already been read in and cached, we'll + # return those results to avoid more file/vault operations + if cache and file_name in self._FILE_CACHE: + parsed_data = self._FILE_CACHE[file_name] + else: + # read the file contents and load the data structure from them + (b_file_data, show_content) = self._get_file_contents(file_name) + + file_data = to_text(b_file_data, errors='surrogate_or_strict') + parsed_data = self.load(data=file_data, file_name=file_name, show_content=show_content, json_only=json_only) + + # cache the file contents for next time + self._FILE_CACHE[file_name] = parsed_data + + if unsafe: + return parsed_data + else: + # return a deep copy here, so the cache is not affected + return copy.deepcopy(parsed_data) + + def path_exists(self, path): + path = self.path_dwim(path) + return os.path.exists(to_bytes(path, errors='surrogate_or_strict')) + + def is_file(self, path): + path = self.path_dwim(path) + return os.path.isfile(to_bytes(path, errors='surrogate_or_strict')) or path == os.devnull + + def is_directory(self, path): + path = self.path_dwim(path) + return os.path.isdir(to_bytes(path, errors='surrogate_or_strict')) + + def list_directory(self, path): + path = self.path_dwim(path) + return os.listdir(path) + + def is_executable(self, path): + '''is the given path executable?''' + path = self.path_dwim(path) + return is_executable(path) + + def _decrypt_if_vault_data(self, b_vault_data, b_file_name=None): + '''Decrypt b_vault_data if encrypted and return b_data and the show_content flag''' + + if not is_encrypted(b_vault_data): + show_content = True + return b_vault_data, show_content + + b_ciphertext, b_version, cipher_name, vault_id = parse_vaulttext_envelope(b_vault_data) + b_data = self._vault.decrypt(b_vault_data, filename=b_file_name) + + show_content = False + return b_data, show_content + + def _get_file_contents(self, 
file_name): + ''' + Reads the file contents from the given file name + + If the contents are vault-encrypted, it will decrypt them and return + the decrypted data + + :arg file_name: The name of the file to read. If this is a relative + path, it will be expanded relative to the basedir + :raises AnsibleFileNotFound: if the file_name does not refer to a file + :raises AnsibleParserError: if we were unable to read the file + :return: Returns a byte string of the file contents + ''' + if not file_name or not isinstance(file_name, (binary_type, text_type)): + raise AnsibleParserError("Invalid filename: '%s'" % to_native(file_name)) + + b_file_name = to_bytes(self.path_dwim(file_name)) + # This is what we really want but have to fix unittests to make it pass + # if not os.path.exists(b_file_name) or not os.path.isfile(b_file_name): + if not self.path_exists(b_file_name): + raise AnsibleFileNotFound("Unable to retrieve file contents", file_name=file_name) + + try: + with open(b_file_name, 'rb') as f: + data = f.read() + return self._decrypt_if_vault_data(data, b_file_name) + except (IOError, OSError) as e: + raise AnsibleParserError("an error occurred while trying to read the file '%s': %s" % (file_name, to_native(e)), orig_exc=e) + + def get_basedir(self): + ''' returns the current basedir ''' + return self._basedir + + def set_basedir(self, basedir): + ''' sets the base directory, used to find files when a relative path is given ''' + + if basedir is not None: + self._basedir = to_text(basedir) + + def path_dwim(self, given): + ''' + make relative paths work like folks expect. 
+ ''' + + given = unquote(given) + given = to_text(given, errors='surrogate_or_strict') + + if given.startswith(to_text(os.path.sep)) or given.startswith(u'~'): + path = given + else: + basedir = to_text(self._basedir, errors='surrogate_or_strict') + path = os.path.join(basedir, given) + + return unfrackpath(path, follow=False) + + def _is_role(self, path): + ''' imperfect role detection, roles are still valid w/o tasks|meta/main.yml|yaml|etc ''' + + b_path = to_bytes(path, errors='surrogate_or_strict') + b_path_dirname = os.path.dirname(b_path) + b_upath = to_bytes(unfrackpath(path, follow=False), errors='surrogate_or_strict') + + untasked_paths = ( + os.path.join(b_path, b'main.yml'), + os.path.join(b_path, b'main.yaml'), + os.path.join(b_path, b'main'), + ) + tasked_paths = ( + os.path.join(b_upath, b'tasks/main.yml'), + os.path.join(b_upath, b'tasks/main.yaml'), + os.path.join(b_upath, b'tasks/main'), + os.path.join(b_upath, b'meta/main.yml'), + os.path.join(b_upath, b'meta/main.yaml'), + os.path.join(b_upath, b'meta/main'), + os.path.join(b_path_dirname, b'tasks/main.yml'), + os.path.join(b_path_dirname, b'tasks/main.yaml'), + os.path.join(b_path_dirname, b'tasks/main'), + os.path.join(b_path_dirname, b'meta/main.yml'), + os.path.join(b_path_dirname, b'meta/main.yaml'), + os.path.join(b_path_dirname, b'meta/main'), + ) + + exists_untasked = map(os.path.exists, untasked_paths) + exists_tasked = map(os.path.exists, tasked_paths) + if RE_TASKS.search(path) and any(exists_untasked) or any(exists_tasked): + return True + + return False + + def path_dwim_relative(self, path, dirname, source, is_role=False): + ''' + find one file in either a role or playbook dir with or without + explicitly named dirname subdirs + + Used in action plugins and lookups to find supplemental files that + could be in either place. 
+ ''' + + search = [] + source = to_text(source, errors='surrogate_or_strict') + + # I have full path, nothing else needs to be looked at + if source.startswith(to_text(os.path.sep)) or source.startswith(u'~'): + search.append(unfrackpath(source, follow=False)) + else: + # base role/play path + templates/files/vars + relative filename + search.append(os.path.join(path, dirname, source)) + basedir = unfrackpath(path, follow=False) + + # not told if role, but detect if it is a role and if so make sure you get correct base path + if not is_role: + is_role = self._is_role(path) + + if is_role and RE_TASKS.search(path): + basedir = unfrackpath(os.path.dirname(path), follow=False) + + cur_basedir = self._basedir + self.set_basedir(basedir) + # resolved base role/play path + templates/files/vars + relative filename + search.append(unfrackpath(os.path.join(basedir, dirname, source), follow=False)) + self.set_basedir(cur_basedir) + + if is_role and not source.endswith(dirname): + # look in role's tasks dir w/o dirname + search.append(unfrackpath(os.path.join(basedir, 'tasks', source), follow=False)) + + # try to create absolute path for loader basedir + templates/files/vars + filename + search.append(unfrackpath(os.path.join(dirname, source), follow=False)) + + # try to create absolute path for loader basedir + search.append(unfrackpath(os.path.join(basedir, source), follow=False)) + + # try to create absolute path for dirname + filename + search.append(self.path_dwim(os.path.join(dirname, source))) + + # try to create absolute path for filename + search.append(self.path_dwim(source)) + + for candidate in search: + if os.path.exists(to_bytes(candidate, errors='surrogate_or_strict')): + break + + return candidate + + def path_dwim_relative_stack(self, paths, dirname, source, is_role=False): + ''' + find one file in first path in stack taking roles into account and adding play basedir as fallback + + :arg paths: A list of text strings which are the paths to look for the 
filename in. + :arg dirname: A text string representing a directory. The directory + is prepended to the source to form the path to search for. + :arg source: A text string which is the filename to search for + :rtype: A text string + :returns: An absolute path to the filename ``source`` if found + :raises: An AnsibleFileNotFound Exception if the file is found to exist in the search paths + ''' + b_dirname = to_bytes(dirname, errors='surrogate_or_strict') + b_source = to_bytes(source, errors='surrogate_or_strict') + + result = None + search = [] + if source is None: + display.warning('Invalid request to find a file that matches a "null" value') + elif source and (source.startswith('~') or source.startswith(os.path.sep)): + # path is absolute, no relative needed, check existence and return source + test_path = unfrackpath(b_source, follow=False) + if os.path.exists(to_bytes(test_path, errors='surrogate_or_strict')): + result = test_path + else: + display.debug(u'evaluation_path:\n\t%s' % '\n\t'.join(paths)) + for path in paths: + upath = unfrackpath(path, follow=False) + b_upath = to_bytes(upath, errors='surrogate_or_strict') + b_pb_base_dir = os.path.dirname(b_upath) + + # if path is in role and 'tasks' not there already, add it into the search + if (is_role or self._is_role(path)) and b_pb_base_dir.endswith(b'/tasks'): + search.append(os.path.join(os.path.dirname(b_pb_base_dir), b_dirname, b_source)) + search.append(os.path.join(b_pb_base_dir, b_source)) + else: + # don't add dirname if user already is using it in source + if b_source.split(b'/')[0] != dirname: + search.append(os.path.join(b_upath, b_dirname, b_source)) + search.append(os.path.join(b_upath, b_source)) + + # always append basedir as last resort + # don't add dirname if user already is using it in source + if b_source.split(b'/')[0] != dirname: + search.append(os.path.join(to_bytes(self.get_basedir(), errors='surrogate_or_strict'), b_dirname, b_source)) + 
search.append(os.path.join(to_bytes(self.get_basedir(), errors='surrogate_or_strict'), b_source)) + + display.debug(u'search_path:\n\t%s' % to_text(b'\n\t'.join(search))) + for b_candidate in search: + display.vvvvv(u'looking for "%s" at "%s"' % (source, to_text(b_candidate))) + if os.path.exists(b_candidate): + result = to_text(b_candidate) + break + + if result is None: + raise AnsibleFileNotFound(file_name=source, paths=[to_native(p) for p in search]) + + return result + + def _create_content_tempfile(self, content): + ''' Create a tempfile containing defined content ''' + fd, content_tempfile = tempfile.mkstemp(dir=C.DEFAULT_LOCAL_TMP) + f = os.fdopen(fd, 'wb') + content = to_bytes(content) + try: + f.write(content) + except Exception as err: + os.remove(content_tempfile) + raise Exception(err) + finally: + f.close() + return content_tempfile + + def get_real_file(self, file_path, decrypt=True): + """ + If the file is vault encrypted return a path to a temporary decrypted file + If the file is not encrypted then the path is returned + Temporary files are cleanup in the destructor + """ + + if not file_path or not isinstance(file_path, (binary_type, text_type)): + raise AnsibleParserError("Invalid filename: '%s'" % to_native(file_path)) + + b_file_path = to_bytes(file_path, errors='surrogate_or_strict') + if not self.path_exists(b_file_path) or not self.is_file(b_file_path): + raise AnsibleFileNotFound(file_name=file_path) + + real_path = self.path_dwim(file_path) + + try: + if decrypt: + with open(to_bytes(real_path), 'rb') as f: + # Limit how much of the file is read since we do not know + # whether this is a vault file and therefore it could be very + # large. 
+ if is_encrypted_file(f, count=len(b_HEADER)): + # if the file is encrypted and no password was specified, + # the decrypt call would throw an error, but we check first + # since the decrypt function doesn't know the file name + data = f.read() + if not self._vault.secrets: + raise AnsibleParserError("A vault password or secret must be specified to decrypt %s" % to_native(file_path)) + + data = self._vault.decrypt(data, filename=real_path) + # Make a temp file + real_path = self._create_content_tempfile(data) + self._tempfiles.add(real_path) + + return real_path + + except (IOError, OSError) as e: + raise AnsibleParserError("an error occurred while trying to read the file '%s': %s" % (to_native(real_path), to_native(e)), orig_exc=e) + + def cleanup_tmp_file(self, file_path): + """ + Removes any temporary files created from a previous call to + get_real_file. file_path must be the path returned from a + previous call to get_real_file. + """ + if file_path in self._tempfiles: + os.unlink(file_path) + self._tempfiles.remove(file_path) + + def cleanup_all_tmp_files(self): + """ + Removes all temporary files that DataLoader has created + NOTE: not thread safe, forks also need special handling see __init__ for details. + """ + for f in list(self._tempfiles): + try: + self.cleanup_tmp_file(f) + except Exception as e: + display.warning("Unable to cleanup temp files: %s" % to_text(e)) + + def find_vars_files(self, path, name, extensions=None, allow_dir=True): + """ + Find vars files in a given path with specified name. This will find + files in a dir named <name>/ or a file called <name> ending in known + extensions. + """ + + b_path = to_bytes(os.path.join(path, name)) + found = [] + + if extensions is None: + # Look for file with no extension first to find dir before file + extensions = [''] + C.YAML_FILENAME_EXTENSIONS + # add valid extensions to name + for ext in extensions: + + if '.' 
in ext: + full_path = b_path + to_bytes(ext) + elif ext: + full_path = b'.'.join([b_path, to_bytes(ext)]) + else: + full_path = b_path + + if self.path_exists(full_path): + if self.is_directory(full_path): + if allow_dir: + found.extend(self._get_dir_vars_files(to_text(full_path), extensions)) + else: + continue + else: + found.append(full_path) + break + return found + + def _get_dir_vars_files(self, path, extensions): + found = [] + for spath in sorted(self.list_directory(path)): + if not spath.startswith(u'.') and not spath.endswith(u'~'): # skip hidden and backups + + ext = os.path.splitext(spath)[-1] + full_spath = os.path.join(path, spath) + + if self.is_directory(full_spath) and not ext: # recursive search if dir + found.extend(self._get_dir_vars_files(full_spath, extensions)) + elif self.is_file(full_spath) and (not ext or to_text(ext) in extensions): + # only consider files with valid extensions or no extension + found.append(full_spath) + + return found diff --git a/lib/ansible/parsing/mod_args.py b/lib/ansible/parsing/mod_args.py new file mode 100644 index 0000000..aeb58b0 --- /dev/null +++ b/lib/ansible/parsing/mod_args.py @@ -0,0 +1,345 @@ +# (c) 2014 Michael DeHaan, <michael@ansible.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. 
+ +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import ansible.constants as C +from ansible.errors import AnsibleParserError, AnsibleError, AnsibleAssertionError +from ansible.module_utils.six import string_types +from ansible.module_utils._text import to_text +from ansible.parsing.splitter import parse_kv, split_args +from ansible.plugins.loader import module_loader, action_loader +from ansible.template import Templar +from ansible.utils.fqcn import add_internal_fqcns +from ansible.utils.sentinel import Sentinel + + +# For filtering out modules correctly below +FREEFORM_ACTIONS = frozenset(C.MODULE_REQUIRE_ARGS) + +RAW_PARAM_MODULES = FREEFORM_ACTIONS.union(add_internal_fqcns(( + 'include', + 'include_vars', + 'include_tasks', + 'include_role', + 'import_tasks', + 'import_role', + 'add_host', + 'group_by', + 'set_fact', + 'meta', +))) + +BUILTIN_TASKS = frozenset(add_internal_fqcns(( + 'meta', + 'include', + 'include_tasks', + 'include_role', + 'import_tasks', + 'import_role' +))) + + +class ModuleArgsParser: + + """ + There are several ways a module and argument set can be expressed: + + # legacy form (for a shell command) + - action: shell echo hi + + # common shorthand for local actions vs delegate_to + - local_action: shell echo hi + + # most commonly: + - copy: src=a dest=b + + # legacy form + - action: copy src=a dest=b + + # complex args form, for passing structured data + - copy: + src: a + dest: b + + # gross, but technically legal + - action: + module: copy + args: + src: a + dest: b + + # Standard YAML form for command-type modules. 
In this case, the args specified + # will act as 'defaults' and will be overridden by any args specified + # in one of the other formats (complex args under the action, or + # parsed from the k=v string + - command: 'pwd' + args: + chdir: '/tmp' + + + This class has some of the logic to canonicalize these into the form + + - module: <module_name> + delegate_to: <optional> + args: <args> + + Args may also be munged for certain shell command parameters. + """ + + def __init__(self, task_ds=None, collection_list=None): + task_ds = {} if task_ds is None else task_ds + + if not isinstance(task_ds, dict): + raise AnsibleAssertionError("the type of 'task_ds' should be a dict, but is a %s" % type(task_ds)) + self._task_ds = task_ds + self._collection_list = collection_list + # delayed local imports to prevent circular import + from ansible.playbook.task import Task + from ansible.playbook.handler import Handler + # store the valid Task/Handler attrs for quick access + self._task_attrs = set(Task.fattributes) + self._task_attrs.update(set(Handler.fattributes)) + # HACK: why are these not FieldAttributes on task with a post-validate to check usage? + self._task_attrs.update(['local_action', 'static']) + self._task_attrs = frozenset(self._task_attrs) + + self.resolved_action = None + + def _split_module_string(self, module_string): + ''' + when module names are expressed like: + action: copy src=a dest=b + the first part of the string is the name of the module + and the rest are strings pertaining to the arguments. + ''' + + tokens = split_args(module_string) + if len(tokens) > 1: + return (tokens[0].strip(), " ".join(tokens[1:])) + else: + return (tokens[0].strip(), "") + + def _normalize_parameters(self, thing, action=None, additional_args=None): + ''' + arguments can be fuzzy. Deal with all the forms. 
+ ''' + + additional_args = {} if additional_args is None else additional_args + + # final args are the ones we'll eventually return, so first update + # them with any additional args specified, which have lower priority + # than those which may be parsed/normalized next + final_args = dict() + if additional_args: + if isinstance(additional_args, string_types): + templar = Templar(loader=None) + if templar.is_template(additional_args): + final_args['_variable_params'] = additional_args + else: + raise AnsibleParserError("Complex args containing variables cannot use bare variables (without Jinja2 delimiters), " + "and must use the full variable style ('{{var_name}}')") + elif isinstance(additional_args, dict): + final_args.update(additional_args) + else: + raise AnsibleParserError('Complex args must be a dictionary or variable string ("{{var}}").') + + # how we normalize depends if we figured out what the module name is + # yet. If we have already figured it out, it's a 'new style' invocation. 
+ # otherwise, it's not + + if action is not None: + args = self._normalize_new_style_args(thing, action) + else: + (action, args) = self._normalize_old_style_args(thing) + + # this can occasionally happen, simplify + if args and 'args' in args: + tmp_args = args.pop('args') + if isinstance(tmp_args, string_types): + tmp_args = parse_kv(tmp_args) + args.update(tmp_args) + + # only internal variables can start with an underscore, so + # we don't allow users to set them directly in arguments + if args and action not in FREEFORM_ACTIONS: + for arg in args: + arg = to_text(arg) + if arg.startswith('_ansible_'): + raise AnsibleError("invalid parameter specified for action '%s': '%s'" % (action, arg)) + + # finally, update the args we're going to return with the ones + # which were normalized above + if args: + final_args.update(args) + + return (action, final_args) + + def _normalize_new_style_args(self, thing, action): + ''' + deals with fuzziness in new style module invocations + accepting key=value pairs and dictionaries, and returns + a dictionary of arguments + + possible example inputs: + 'echo hi', 'shell' + {'region': 'xyz'}, 'ec2' + standardized outputs like: + { _raw_params: 'echo hi', _uses_shell: True } + ''' + + if isinstance(thing, dict): + # form is like: { xyz: { x: 2, y: 3 } } + args = thing + elif isinstance(thing, string_types): + # form is like: copy: src=a dest=b + check_raw = action in FREEFORM_ACTIONS + args = parse_kv(thing, check_raw=check_raw) + elif thing is None: + # this can happen with modules which take no params, like ping: + args = None + else: + raise AnsibleParserError("unexpected parameter type in action: %s" % type(thing), obj=self._task_ds) + return args + + def _normalize_old_style_args(self, thing): + ''' + deals with fuzziness in old-style (action/local_action) module invocations + returns tuple of (module_name, dictionary_args) + + possible example inputs: + { 'shell' : 'echo hi' } + 'shell echo hi' + {'module': 'ec2', 'x': 1 } 
+ standardized outputs like: + ('ec2', { 'x': 1} ) + ''' + + action = None + args = None + + if isinstance(thing, dict): + # form is like: action: { module: 'copy', src: 'a', dest: 'b' } + thing = thing.copy() + if 'module' in thing: + action, module_args = self._split_module_string(thing['module']) + args = thing.copy() + check_raw = action in FREEFORM_ACTIONS + args.update(parse_kv(module_args, check_raw=check_raw)) + del args['module'] + + elif isinstance(thing, string_types): + # form is like: action: copy src=a dest=b + (action, args) = self._split_module_string(thing) + check_raw = action in FREEFORM_ACTIONS + args = parse_kv(args, check_raw=check_raw) + + else: + # need a dict or a string, so giving up + raise AnsibleParserError("unexpected parameter type in action: %s" % type(thing), obj=self._task_ds) + + return (action, args) + + def parse(self, skip_action_validation=False): + ''' + Given a task in one of the supported forms, parses and returns + returns the action, arguments, and delegate_to values for the + task, dealing with all sorts of levels of fuzziness. + ''' + + thing = None + + action = None + delegate_to = self._task_ds.get('delegate_to', Sentinel) + args = dict() + + # This is the standard YAML form for command-type modules. 
We grab + # the args and pass them in as additional arguments, which can/will + # be overwritten via dict updates from the other arg sources below + additional_args = self._task_ds.get('args', dict()) + + # We can have one of action, local_action, or module specified + # action + if 'action' in self._task_ds: + # an old school 'action' statement + thing = self._task_ds['action'] + action, args = self._normalize_parameters(thing, action=action, additional_args=additional_args) + + # local_action + if 'local_action' in self._task_ds: + # local_action is similar but also implies a delegate_to + if action is not None: + raise AnsibleParserError("action and local_action are mutually exclusive", obj=self._task_ds) + thing = self._task_ds.get('local_action', '') + delegate_to = 'localhost' + action, args = self._normalize_parameters(thing, action=action, additional_args=additional_args) + + # module: <stuff> is the more new-style invocation + + # filter out task attributes so we're only querying unrecognized keys as actions/modules + non_task_ds = dict((k, v) for k, v in self._task_ds.items() if (k not in self._task_attrs) and (not k.startswith('with_'))) + + # walk the filtered input dictionary to see if we recognize a module name + for item, value in non_task_ds.items(): + context = None + is_action_candidate = False + if item in BUILTIN_TASKS: + is_action_candidate = True + elif skip_action_validation: + is_action_candidate = True + else: + context = action_loader.find_plugin_with_context(item, collection_list=self._collection_list) + if not context.resolved: + context = module_loader.find_plugin_with_context(item, collection_list=self._collection_list) + + is_action_candidate = context.resolved and bool(context.redirect_list) + + if is_action_candidate: + # finding more than one module name is a problem + if action is not None: + raise AnsibleParserError("conflicting action statements: %s, %s" % (action, item), obj=self._task_ds) + + if context is not None and 
context.resolved: + self.resolved_action = context.resolved_fqcn + + action = item + thing = value + action, args = self._normalize_parameters(thing, action=action, additional_args=additional_args) + + # if we didn't see any module in the task at all, it's not a task really + if action is None: + if non_task_ds: # there was one non-task action, but we couldn't find it + bad_action = list(non_task_ds.keys())[0] + raise AnsibleParserError("couldn't resolve module/action '{0}'. This often indicates a " + "misspelling, missing collection, or incorrect module path.".format(bad_action), + obj=self._task_ds) + else: + raise AnsibleParserError("no module/action detected in task.", + obj=self._task_ds) + elif args.get('_raw_params', '') != '' and action not in RAW_PARAM_MODULES: + templar = Templar(loader=None) + raw_params = args.pop('_raw_params') + if templar.is_template(raw_params): + args['_variable_params'] = raw_params + else: + raise AnsibleParserError("this task '%s' has extra params, which is only allowed in the following modules: %s" % (action, + ", ".join(RAW_PARAM_MODULES)), + obj=self._task_ds) + + return (action, args, delegate_to) diff --git a/lib/ansible/parsing/plugin_docs.py b/lib/ansible/parsing/plugin_docs.py new file mode 100644 index 0000000..cda5463 --- /dev/null +++ b/lib/ansible/parsing/plugin_docs.py @@ -0,0 +1,227 @@ +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import ast +import tokenize + +from ansible import constants as C +from ansible.errors import AnsibleError, AnsibleParserError +from ansible.module_utils._text import to_text, to_native +from ansible.parsing.yaml.loader import AnsibleLoader +from ansible.utils.display import Display + +display = Display() + + +string_to_vars = { + 'DOCUMENTATION': 'doc', + 'EXAMPLES': 'plainexamples', + 'RETURN': 
'returndocs',
    'ANSIBLE_METADATA': 'metadata',  # NOTE: now unused, but kept for backwards compat
}


def _var2string(value):
    ''' reverse lookup of the dict above; returns None when no key maps to value '''
    for k, v in string_to_vars.items():
        if v == value:
            return k


def _init_doc_dict():
    ''' initialize a return dict for docs with the expected structure '''
    # one entry per documentation variable: 'doc', 'plainexamples', 'returndocs', 'metadata'
    return {k: None for k in string_to_vars.values()}


def read_docstring_from_yaml_file(filename, verbose=True, ignore_errors=True):
    '''
    Read docs from 'sidecar' yaml file doc for a plugin.

    :arg filename: path to the YAML sidecar documentation file
    :kwarg verbose: when True (and ignore_errors is True), parse failures are shown via display.error
    :kwarg ignore_errors: when False, a parse failure raises AnsibleParserError instead of being reported
    :returns: dict shaped by _init_doc_dict(); values stay None for keys absent from the file
    '''

    data = _init_doc_dict()
    file_data = {}

    try:
        with open(filename, 'rb') as yamlfile:
            file_data = AnsibleLoader(yamlfile.read(), file_name=filename).get_single_data()
    except Exception as e:
        # deliberately broad: any read/parse failure is reported (or raised) uniformly
        msg = "Unable to parse yaml file '%s': %s" % (filename, to_native(e))
        if not ignore_errors:
            raise AnsibleParserError(msg, orig_exc=e)
        elif verbose:
            display.error(msg)

    if file_data:
        # map the ALL-CAPS documentation variable names to their normalized result keys
        for key in string_to_vars:
            data[string_to_vars[key]] = file_data.get(key, None)

    return data


def read_docstring_from_python_module(filename, verbose=True, ignore_errors=True):
    """
    Use tokenization to search for assignment of the documentation variables in the given file.
    Parse from YAML and return the resulting python structure or None together with examples as plain text.
+ """ + + seen = set() + data = _init_doc_dict() + + next_string = None + with tokenize.open(filename) as f: + tokens = tokenize.generate_tokens(f.readline) + for token in tokens: + + # found lable that looks like variable + if token.type == tokenize.NAME: + + # label is expected value, in correct place and has not been seen before + if token.start == 1 and token.string in string_to_vars and token.string not in seen: + # next token that is string has the docs + next_string = string_to_vars[token.string] + continue + + # previous token indicated this string is a doc string + if next_string is not None and token.type == tokenize.STRING: + + # ensure we only process one case of it + seen.add(token.string) + + value = token.string + + # strip string modifiers/delimiters + if value.startswith(('r', 'b')): + value = value.lstrip('rb') + + if value.startswith(("'", '"')): + value = value.strip("'\"") + + # actually use the data + if next_string == 'plainexamples': + # keep as string, can be yaml, but we let caller deal with it + data[next_string] = to_text(value) + else: + # yaml load the data + try: + data[next_string] = AnsibleLoader(value, file_name=filename).get_single_data() + except Exception as e: + msg = "Unable to parse docs '%s' in python file '%s': %s" % (_var2string(next_string), filename, to_native(e)) + if not ignore_errors: + raise AnsibleParserError(msg, orig_exc=e) + elif verbose: + display.error(msg) + + next_string = None + + # if nothing else worked, fall back to old method + if not seen: + data = read_docstring_from_python_file(filename, verbose, ignore_errors) + + return data + + +def read_docstring_from_python_file(filename, verbose=True, ignore_errors=True): + """ + Use ast to search for assignment of the DOCUMENTATION and EXAMPLES variables in the given file. + Parse DOCUMENTATION from YAML and return the YAML doc or None together with EXAMPLES, as plain text. 
+ """ + + data = _init_doc_dict() + + try: + with open(filename, 'rb') as b_module_data: + M = ast.parse(b_module_data.read()) + + for child in M.body: + if isinstance(child, ast.Assign): + for t in child.targets: + try: + theid = t.id + except AttributeError: + # skip errors can happen when trying to use the normal code + display.warning("Building documentation, failed to assign id for %s on %s, skipping" % (t, filename)) + continue + + if theid in string_to_vars: + varkey = string_to_vars[theid] + if isinstance(child.value, ast.Dict): + data[varkey] = ast.literal_eval(child.value) + else: + if theid == 'EXAMPLES': + # examples 'can' be yaml, but even if so, we dont want to parse as such here + # as it can create undesired 'objects' that don't display well as docs. + data[varkey] = to_text(child.value.s) + else: + # string should be yaml if already not a dict + data[varkey] = AnsibleLoader(child.value.s, file_name=filename).get_single_data() + + display.debug('Documentation assigned: %s' % varkey) + + except Exception as e: + msg = "Unable to parse documentation in python file '%s': %s" % (filename, to_native(e)) + if not ignore_errors: + raise AnsibleParserError(msg, orig_exc=e) + elif verbose: + display.error(msg) + + return data + + +def read_docstring(filename, verbose=True, ignore_errors=True): + ''' returns a documentation dictionary from Ansible plugin docstrings ''' + + # NOTE: adjacency of doc file to code file is responsibility of caller + if filename.endswith(C.YAML_DOC_EXTENSIONS): + docstring = read_docstring_from_yaml_file(filename, verbose=verbose, ignore_errors=ignore_errors) + elif filename.endswith(C.PYTHON_DOC_EXTENSIONS): + docstring = read_docstring_from_python_module(filename, verbose=verbose, ignore_errors=ignore_errors) + elif not ignore_errors: + raise AnsibleError("Unknown documentation format: %s" % to_native(filename)) + + if not docstring and not ignore_errors: + raise AnsibleError("Unable to parse documentation for: %s" % 
to_native(filename)) + + # cause seealso is specially processed from 'doc' later on + # TODO: stop any other 'overloaded' implementation in main doc + docstring['seealso'] = None + + return docstring + + +def read_docstub(filename): + """ + Quickly find short_description using string methods instead of node parsing. + This does not return a full set of documentation strings and is intended for + operations like ansible-doc -l. + """ + + in_documentation = False + capturing = False + indent_detection = '' + doc_stub = [] + + with open(filename, 'r') as t_module_data: + for line in t_module_data: + if in_documentation: + # start capturing the stub until indentation returns + if capturing and line.startswith(indent_detection): + doc_stub.append(line) + + elif capturing and not line.startswith(indent_detection): + break + + elif line.lstrip().startswith('short_description:'): + capturing = True + # Detect that the short_description continues on the next line if it's indented more + # than short_description itself. + indent_detection = ' ' * (len(line) - len(line.lstrip()) + 1) + doc_stub.append(line) + + elif line.startswith('DOCUMENTATION') and ('=' in line or ':' in line): + in_documentation = True + + short_description = r''.join(doc_stub).strip().rstrip('.') + data = AnsibleLoader(short_description, file_name=filename).get_single_data() + + return data diff --git a/lib/ansible/parsing/quoting.py b/lib/ansible/parsing/quoting.py new file mode 100644 index 0000000..d3a38d9 --- /dev/null +++ b/lib/ansible/parsing/quoting.py @@ -0,0 +1,31 @@ +# (c) 2014 James Cammarata, <jcammarata@ansible.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type


def is_quoted(data):
    '''
    Check whether a string is wrapped in a matching pair of quotes.

    True only when the string is at least two characters long, begins and ends
    with the same quote character (single or double), and the closing quote is
    not escaped by a backslash.
    '''
    if len(data) < 2:
        return False
    opening, closing = data[0], data[-1]
    if opening not in ('"', "'") or opening != closing:
        return False
    # a backslash right before the closing quote means it is escaped, not a delimiter
    return data[-2] != '\\'


def unquote(data):
    ''' removes first and last quotes from a string, if the string starts and ends with the same quotes '''
    return data[1:-1] if is_quoted(data) else data
+ +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import codecs +import re + +from ansible.errors import AnsibleParserError +from ansible.module_utils._text import to_text +from ansible.parsing.quoting import unquote + + +# Decode escapes adapted from rspeer's answer here: +# http://stackoverflow.com/questions/4020539/process-escape-sequences-in-a-string-in-python +_HEXCHAR = '[a-fA-F0-9]' +_ESCAPE_SEQUENCE_RE = re.compile(r''' + ( \\U{0} # 8-digit hex escapes + | \\u{1} # 4-digit hex escapes + | \\x{2} # 2-digit hex escapes + | \\N\{{[^}}]+\}} # Unicode characters by name + | \\[\\'"abfnrtv] # Single-character escapes + )'''.format(_HEXCHAR * 8, _HEXCHAR * 4, _HEXCHAR * 2), re.UNICODE | re.VERBOSE) + + +def _decode_escapes(s): + def decode_match(match): + return codecs.decode(match.group(0), 'unicode-escape') + + return _ESCAPE_SEQUENCE_RE.sub(decode_match, s) + + +def parse_kv(args, check_raw=False): + ''' + Convert a string of key/value items to a dict. If any free-form params + are found and the check_raw option is set to True, they will be added + to a new parameter called '_raw_params'. If check_raw is not enabled, + they will simply be ignored. 
+ ''' + + args = to_text(args, nonstring='passthru') + + options = {} + if args is not None: + try: + vargs = split_args(args) + except IndexError as e: + raise AnsibleParserError("Unable to parse argument string", orig_exc=e) + except ValueError as ve: + if 'no closing quotation' in str(ve).lower(): + raise AnsibleParserError("error parsing argument string, try quoting the entire line.", orig_exc=ve) + else: + raise + + raw_params = [] + for orig_x in vargs: + x = _decode_escapes(orig_x) + if "=" in x: + pos = 0 + try: + while True: + pos = x.index('=', pos + 1) + if pos > 0 and x[pos - 1] != '\\': + break + except ValueError: + # ran out of string, but we must have some escaped equals, + # so replace those and append this to the list of raw params + raw_params.append(x.replace('\\=', '=')) + continue + + k = x[:pos] + v = x[pos + 1:] + + # FIXME: make the retrieval of this list of shell/command options a function, so the list is centralized + if check_raw and k not in ('creates', 'removes', 'chdir', 'executable', 'warn', 'stdin', 'stdin_add_newline', 'strip_empty_ends'): + raw_params.append(orig_x) + else: + options[k.strip()] = unquote(v.strip()) + else: + raw_params.append(orig_x) + + # recombine the free-form params, if any were found, and assign + # them to a special option for use later by the shell/command module + if len(raw_params) > 0: + options[u'_raw_params'] = join_args(raw_params) + + return options + + +def _get_quote_state(token, quote_char): + ''' + the goal of this block is to determine if the quoted string + is unterminated in which case it needs to be put back together + ''' + # the char before the current one, used to see if + # the current character is escaped + prev_char = None + for idx, cur_char in enumerate(token): + if idx > 0: + prev_char = token[idx - 1] + if cur_char in '"\'' and prev_char != '\\': + if quote_char: + if cur_char == quote_char: + quote_char = None + else: + quote_char = cur_char + return quote_char + + +def 
_count_jinja2_blocks(token, cur_depth, open_token, close_token):
    '''
    this function counts the number of opening/closing blocks for a
    given opening/closing type and adjusts the current depth for that
    block based on the difference

    :arg token: text chunk to scan
    :arg cur_depth: nesting depth carried over from previous tokens
    :arg open_token/close_token: the jinja2 delimiters to count, e.g. "{{" / "}}"
    :returns: the adjusted depth, clamped so it never goes below zero
    '''
    num_open = token.count(open_token)
    num_close = token.count(close_token)
    if num_open != num_close:
        cur_depth += (num_open - num_close)
        if cur_depth < 0:
            cur_depth = 0
    return cur_depth


def join_args(s):
    '''
    Join the original cmd based on manipulations by split_args().
    This retains the original newlines and whitespaces.
    '''
    result = ''
    for p in s:
        # a trailing newline was preserved by split_args(); do not add a space after it
        if len(result) == 0 or result.endswith('\n'):
            result += p
        else:
            result += ' ' + p
    return result


def split_args(args):
    '''
    Splits args on whitespace, but intelligently reassembles
    those that may have been split over a jinja2 block or quotes.

    When used in a remote module, we won't ever have to be concerned about
    jinja2 blocks, however this function is/will be used in the
    core portions as well before the args are templated.

    example input: a=b c="foo bar"
    example output: ['a=b', 'c="foo bar"']

    Basically this is a variation of shlex that has some more intelligence for
    how Ansible needs to use it.
    '''

    # the list of params parsed out of the arg string
    # this is going to be the result value when we are done
    params = []

    # Initial split on newlines
    items = args.split('\n')

    # iterate over the tokens, and reassemble any that may have been
    # split on a space inside a jinja2 block.
    # ex if tokens are "{{", "foo", "}}" these go together

    # These variables are used
    # to keep track of the state of the parsing, since blocks and quotes
    # may be nested within each other.
    quote_char = None
    inside_quotes = False
    print_depth = 0  # used to count nested jinja2 {{ }} blocks
    block_depth = 0  # used to count nested jinja2 {% %} blocks
    comment_depth = 0  # used to count nested jinja2 {# #} blocks

    # now we loop over each split chunk, coalescing tokens if the white space
    # split occurred within quotes or a jinja2 block of some kind
    for (itemidx, item) in enumerate(items):

        # we split on spaces and newlines separately, so that we
        # can tell which character we split on for reassembly
        # inside quotation characters
        tokens = item.split(' ')

        line_continuation = False
        for (idx, token) in enumerate(tokens):

            # Empty entries means we have subsequent spaces
            # We want to hold onto them so we can reconstruct them later
            if len(token) == 0 and idx != 0:
                params[-1] += ' '
                continue

            # if we hit a line continuation character, but
            # we're not inside quotes, ignore it and continue
            # on to the next token while setting a flag
            if token == '\\' and not inside_quotes:
                line_continuation = True
                continue

            # store the previous quoting state for checking later
            was_inside_quotes = inside_quotes
            quote_char = _get_quote_state(token, quote_char)
            inside_quotes = quote_char is not None

            # multiple conditions may append a token to the list of params,
            # so we keep track with this flag to make sure it only happens once
            # append means add to the end of the list, don't append means concatenate
            # it to the end of the last token
            appended = False

            # if we're inside quotes now, but weren't before, append the token
            # to the end of the list, since we'll tack on more to it later
            # otherwise, if we're inside any jinja2 block, inside quotes, or we were
            # inside quotes (but aren't now) concat this token to the last param
            if inside_quotes and not was_inside_quotes and not (print_depth or block_depth or comment_depth):
                params.append(token)
                appended = True
            elif print_depth or block_depth or comment_depth or inside_quotes or was_inside_quotes:
                if idx == 0 and was_inside_quotes:
                    # quoted span continues from the previous line: concat with no separator
                    params[-1] = "%s%s" % (params[-1], token)
                elif len(tokens) > 1:
                    # we split this item on spaces, so restore a space between pieces
                    spacer = ''
                    if idx > 0:
                        spacer = ' '
                    params[-1] = "%s%s%s" % (params[-1], spacer, token)
                else:
                    # single-token item: the original separator was a newline
                    params[-1] = "%s\n%s" % (params[-1], token)
                appended = True

            # if the number of paired block tags is not the same, the depth has changed, so we calculate that here
            # and may append the current token to the params (if we haven't previously done so)
            prev_print_depth = print_depth
            print_depth = _count_jinja2_blocks(token, print_depth, "{{", "}}")
            if print_depth != prev_print_depth and not appended:
                params.append(token)
                appended = True

            prev_block_depth = block_depth
            block_depth = _count_jinja2_blocks(token, block_depth, "{%", "%}")
            if block_depth != prev_block_depth and not appended:
                params.append(token)
                appended = True

            prev_comment_depth = comment_depth
            comment_depth = _count_jinja2_blocks(token, comment_depth, "{#", "#}")
            if comment_depth != prev_comment_depth and not appended:
                params.append(token)
                appended = True

            # finally, if we're at zero depth for all blocks and not inside quotes, and have not
            # yet appended anything to the list of params, we do so now
            if not (print_depth or block_depth or comment_depth) and not inside_quotes and not appended and token != '':
                params.append(token)

        # if this was NOT the last item, and we have more than one item
        # (meaning we split on newlines), add a newline back here
        # to preserve the original structure
        if len(items) > 1 and itemidx != len(items) - 1 and not line_continuation:
            params[-1] += '\n'

        # always clear the line continuation flag
        line_continuation = False

    # If we're done and things are not at zero depth or we're still inside quotes,
    # raise an error to indicate that the args were unbalanced
    if print_depth or block_depth or comment_depth or inside_quotes:
        raise
AnsibleParserError(u"failed at splitting arguments, either an unbalanced jinja2 block or quotes: {0}".format(args)) + + return params diff --git a/lib/ansible/parsing/utils/__init__.py b/lib/ansible/parsing/utils/__init__.py new file mode 100644 index 0000000..ae8ccff --- /dev/null +++ b/lib/ansible/parsing/utils/__init__.py @@ -0,0 +1,20 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type diff --git a/lib/ansible/parsing/utils/addresses.py b/lib/ansible/parsing/utils/addresses.py new file mode 100644 index 0000000..0096af4 --- /dev/null +++ b/lib/ansible/parsing/utils/addresses.py @@ -0,0 +1,216 @@ +# Copyright 2015 Abhijit Menon-Sen <ams@2ndQuadrant.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. 
+# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re +from ansible.errors import AnsibleParserError, AnsibleError + +# Components that match a numeric or alphanumeric begin:end or begin:end:step +# range expression inside square brackets. + +numeric_range = r''' + \[ + (?:[0-9]+:[0-9]+) # numeric begin:end + (?::[0-9]+)? # numeric :step (optional) + \] +''' + +hexadecimal_range = r''' + \[ + (?:[0-9a-f]+:[0-9a-f]+) # hexadecimal begin:end + (?::[0-9]+)? # numeric :step (optional) + \] +''' + +alphanumeric_range = r''' + \[ + (?: + [a-z]:[a-z]| # one-char alphabetic range + [0-9]+:[0-9]+ # ...or a numeric one + ) + (?::[0-9]+)? # numeric :step (optional) + \] +''' + +# Components that match a 16-bit portion of an IPv6 address in hexadecimal +# notation (0..ffff) or an 8-bit portion of an IPv4 address in decimal notation +# (0..255) or an [x:y(:z)] numeric range. + +ipv6_component = r''' + (?: + [0-9a-f]{{1,4}}| # 0..ffff + {range} # or a numeric range + ) +'''.format(range=hexadecimal_range) + +ipv4_component = r''' + (?: + [01]?[0-9]{{1,2}}| # 0..199 + 2[0-4][0-9]| # 200..249 + 25[0-5]| # 250..255 + {range} # or a numeric range + ) +'''.format(range=numeric_range) + +# A hostname label, e.g. 'foo' in 'foo.example.com'. Consists of alphanumeric +# characters plus dashes (and underscores) or valid ranges. The label may not +# start or end with a hyphen or an underscore. This is interpolated into the +# hostname pattern below. We don't try to enforce the 63-char length limit. 
+ +label = r''' + (?:[\w]|{range}) # Starts with an alphanumeric or a range + (?:[\w_-]|{range})* # Then zero or more of the same or [_-] + (?<![_-]) # ...as long as it didn't end with [_-] +'''.format(range=alphanumeric_range) + +patterns = { + # This matches a square-bracketed expression with a port specification. What + # is inside the square brackets is validated later. + + 'bracketed_hostport': re.compile( + r'''^ + \[(.+)\] # [host identifier] + :([0-9]+) # :port number + $ + ''', re.X + ), + + # This matches a bare IPv4 address or hostname (or host pattern including + # [x:y(:z)] ranges) with a port specification. + + 'hostport': re.compile( + r'''^ + ((?: # We want to match: + [^:\[\]] # (a non-range character + | # ...or... + \[[^\]]*\] # a complete bracketed expression) + )*) # repeated as many times as possible + :([0-9]+) # followed by a port number + $ + ''', re.X + ), + + # This matches an IPv4 address, but also permits range expressions. + + 'ipv4': re.compile( + r'''^ + (?:{i4}\.){{3}}{i4} # Three parts followed by dots plus one + $ + '''.format(i4=ipv4_component), re.X | re.I + ), + + # This matches an IPv6 address, but also permits range expressions. + # + # This expression looks complex, but it really only spells out the various + # combinations in which the basic unit of an IPv6 address (0..ffff) can be + # written, from :: to 1:2:3:4:5:6:7:8, plus the IPv4-in-IPv6 variants such + # as ::ffff:192.0.2.3. + # + # Note that we can't just use ipaddress.ip_address() because we also have to + # accept ranges in place of each component. + + 'ipv6': re.compile( + r'''^ + (?:{0}:){{7}}{0}| # uncompressed: 1:2:3:4:5:6:7:8 + (?:{0}:){{1,6}}:| # compressed variants, which are all + (?:{0}:)(?::{0}){{1,6}}| # a::b for various lengths of a,b + (?:{0}:){{2}}(?::{0}){{1,5}}| + (?:{0}:){{3}}(?::{0}){{1,4}}| + (?:{0}:){{4}}(?::{0}){{1,3}}| + (?:{0}:){{5}}(?::{0}){{1,2}}| + (?:{0}:){{6}}(?::{0})| # ...all with 2 <= a+b <= 7 + :(?::{0}){{1,6}}| # ::ffff(:ffff...) 
+ {0}?::| # ffff::, :: + # ipv4-in-ipv6 variants + (?:0:){{6}}(?:{0}\.){{3}}{0}| + ::(?:ffff:)?(?:{0}\.){{3}}{0}| + (?:0:){{5}}ffff:(?:{0}\.){{3}}{0} + $ + '''.format(ipv6_component), re.X | re.I + ), + + # This matches a hostname or host pattern including [x:y(:z)] ranges. + # + # We roughly follow DNS rules here, but also allow ranges (and underscores). + # In the past, no systematic rules were enforced about inventory hostnames, + # but the parsing context (e.g. shlex.split(), fnmatch.fnmatch()) excluded + # various metacharacters anyway. + # + # We don't enforce DNS length restrictions here (63 characters per label, + # 253 characters total) or make any attempt to process IDNs. + + 'hostname': re.compile( + r'''^ + {label} # We must have at least one label + (?:\.{label})* # Followed by zero or more .labels + $ + '''.format(label=label), re.X | re.I | re.UNICODE + ), + +} + + +def parse_address(address, allow_ranges=False): + """ + Takes a string and returns a (host, port) tuple. If the host is None, then + the string could not be parsed as a host identifier with an optional port + specification. If the port is None, then no port was specified. + + The host identifier may be a hostname (qualified or not), an IPv4 address, + or an IPv6 address. If allow_ranges is True, then any of those may contain + [x:y] range specifications, e.g. foo[1:3] or foo[0:5]-bar[x-z]. + + The port number is an optional :NN suffix on an IPv4 address or host name, + or a mandatory :NN suffix on any square-bracketed expression: IPv6 address, + IPv4 address, or host name. (This means the only way to specify a port for + an IPv6 address is to enclose it in square brackets.) + """ + + # First, we extract the port number if one is specified. 
+ + port = None + for matching in ['bracketed_hostport', 'hostport']: + m = patterns[matching].match(address) + if m: + (address, port) = m.groups() + port = int(port) + continue + + # What we're left with now must be an IPv4 or IPv6 address, possibly with + # numeric ranges, or a hostname with alphanumeric ranges. + + host = None + for matching in ['ipv4', 'ipv6', 'hostname']: + m = patterns[matching].match(address) + if m: + host = address + continue + + # If it isn't any of the above, we don't understand it. + if not host: + raise AnsibleError("Not a valid network hostname: %s" % address) + + # If we get to this point, we know that any included ranges are valid. + # If the caller is prepared to handle them, all is well. + # Otherwise we treat it as a parse failure. + if not allow_ranges and '[' in host: + raise AnsibleParserError("Detected range in host but was asked to ignore ranges") + + return (host, port) diff --git a/lib/ansible/parsing/utils/jsonify.py b/lib/ansible/parsing/utils/jsonify.py new file mode 100644 index 0000000..19ebc56 --- /dev/null +++ b/lib/ansible/parsing/utils/jsonify.py @@ -0,0 +1,38 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. 
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type

import json


def jsonify(result, format=False):
    '''
    Format a python object as a JSON string (compact or pretty-printed).

    :arg result: object to serialize; None is rendered as an empty JSON object
    :kwarg format: when True, pretty-print with 4-space indentation; otherwise emit compact output
    :returns: JSON text with keys sorted for stable output
    '''

    if result is None:
        return "{}"

    indent = None
    if format:
        indent = 4

    try:
        # prefer readable non-ascii output; fall back to ascii-escaped output when
        # the data contains byte strings that cannot be decoded
        return json.dumps(result, sort_keys=True, indent=indent, ensure_ascii=False)
    except UnicodeDecodeError:
        return json.dumps(result, sort_keys=True, indent=indent)
+ ''' + + # if the YAML exception contains a problem mark, use it to construct + # an object the error class can use to display the faulty line + err_obj = None + if hasattr(yaml_exc, 'problem_mark'): + err_obj = AnsibleBaseYAMLObject() + err_obj.ansible_pos = (file_name, yaml_exc.problem_mark.line + 1, yaml_exc.problem_mark.column + 1) + + n_yaml_syntax_error = YAML_SYNTAX_ERROR % to_native(getattr(yaml_exc, 'problem', u'')) + n_err_msg = 'We were unable to read either as JSON nor YAML, these are the errors we got from each:\n' \ + 'JSON: %s\n\n%s' % (to_native(json_exc), n_yaml_syntax_error) + + raise AnsibleParserError(n_err_msg, obj=err_obj, show_content=show_content, orig_exc=yaml_exc) + + +def _safe_load(stream, file_name=None, vault_secrets=None): + ''' Implements yaml.safe_load(), except using our custom loader class. ''' + + loader = AnsibleLoader(stream, file_name, vault_secrets) + try: + return loader.get_single_data() + finally: + try: + loader.dispose() + except AttributeError: + pass # older versions of yaml don't have dispose function, ignore + + +def from_yaml(data, file_name='<string>', show_content=True, vault_secrets=None, json_only=False): + ''' + Creates a python datastructure from the given data, which can be either + a JSON or YAML string. + ''' + new_data = None + + try: + # in case we have to deal with vaults + AnsibleJSONDecoder.set_secrets(vault_secrets) + + # we first try to load this data as JSON. 
+ # Fixes issues with extra vars json strings not being parsed correctly by the yaml parser + new_data = json.loads(data, cls=AnsibleJSONDecoder) + except Exception as json_exc: + + if json_only: + raise AnsibleParserError(to_native(json_exc), orig_exc=json_exc) + + # must not be JSON, let the rest try + try: + new_data = _safe_load(data, file_name=file_name, vault_secrets=vault_secrets) + except YAMLError as yaml_exc: + _handle_error(json_exc, yaml_exc, file_name, show_content) + + return new_data diff --git a/lib/ansible/parsing/vault/__init__.py b/lib/ansible/parsing/vault/__init__.py new file mode 100644 index 0000000..8ac22d4 --- /dev/null +++ b/lib/ansible/parsing/vault/__init__.py @@ -0,0 +1,1289 @@ +# (c) 2014, James Tanner <tanner.jc@gmail.com> +# (c) 2016, Adrian Likins <alikins@redhat.com> +# (c) 2016 Toshio Kuratomi <tkuratomi@ansible.com> +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. 
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type

import errno
import fcntl
import os
import random
import shlex
import shutil
import subprocess
import sys
import tempfile
import warnings

from binascii import hexlify
from binascii import unhexlify
from binascii import Error as BinasciiError

# The third-party 'cryptography' library is optional at import time so that a
# helpful error (NEED_CRYPTO_LIBRARY) can be raised later instead of failing
# the import of this module; HAS_CRYPTOGRAPHY records whether it is usable.
HAS_CRYPTOGRAPHY = False
CRYPTOGRAPHY_BACKEND = None
try:
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        from cryptography.exceptions import InvalidSignature
        from cryptography.hazmat.backends import default_backend
        from cryptography.hazmat.primitives import hashes, padding
        from cryptography.hazmat.primitives.hmac import HMAC
        from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
        from cryptography.hazmat.primitives.ciphers import (
            Cipher as C_Cipher, algorithms, modes
        )
    CRYPTOGRAPHY_BACKEND = default_backend()
    HAS_CRYPTOGRAPHY = True
except ImportError:
    pass

from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible import constants as C
from ansible.module_utils.six import binary_type
from ansible.module_utils._text import to_bytes, to_text, to_native
from ansible.utils.display import Display
from ansible.utils.path import makedirs_safe, unfrackpath

display = Display()


# Marker bytes that begin every vault envelope; the envelope format is pure ASCII.
b_HEADER = b'$ANSIBLE_VAULT'
CIPHER_WHITELIST = frozenset((u'AES256',))
CIPHER_WRITE_WHITELIST = frozenset((u'AES256',))
# See also CIPHER_MAPPING at the bottom of the file which maps cipher strings
# (used in VaultFile header) to a cipher class

NEED_CRYPTO_LIBRARY = "ansible-vault requires the cryptography library in order to function"


class AnsibleVaultError(AnsibleError):
    """Base exception for vault operations."""
    pass


class AnsibleVaultPasswordError(AnsibleVaultError):
    """Raised when a vault password/secret fails validation (e.g. empty)."""
    pass


class AnsibleVaultFormatError(AnsibleError):
    """Raised when vaulttext or its envelope cannot be parsed."""
    pass


def is_encrypted(data):
    """ Test if this is vault encrypted data blob

    :arg data: a byte or text string to test whether it is recognized as vault
        encrypted data
    :returns: True if it is recognized.  Otherwise, False.
    """
    try:
        # Make sure we have a byte string and that it only contains ascii
        # bytes.
        b_data = to_bytes(to_text(data, encoding='ascii', errors='strict', nonstring='strict'), encoding='ascii', errors='strict')
    except (UnicodeError, TypeError):
        # The vault format is pure ascii so if we failed to encode to bytes
        # via ascii we know that this is not vault data.
        # Similarly, if it's not a string, it's not vault data
        return False

    if b_data.startswith(b_HEADER):
        return True
    return False


def is_encrypted_file(file_obj, start_pos=0, count=-1):
    """Test if the contents of a file obj are a vault encrypted data blob.

    :arg file_obj: A file object that will be read from.
    :kwarg start_pos: A byte offset in the file to start reading the header
        from.  Defaults to 0, the beginning of the file.
    :kwarg count: Read up to this number of bytes from the file to determine
        if it looks like encrypted vault data.  The default is -1, read to the
        end of file.
    :returns: True if the file looks like a vault file. Otherwise, False.
    """
    # read the header and reset the file stream to where it started
    current_position = file_obj.tell()
    try:
        file_obj.seek(start_pos)
        return is_encrypted(file_obj.read(count))

    finally:
        file_obj.seek(current_position)


def _parse_vaulttext_envelope(b_vaulttext_envelope, default_vault_id=None):
    """Split an envelope into (ciphertext, version, cipher_name, vault_id).

    The first line is the header ``$ANSIBLE_VAULT;<version>;<cipher>[;<vault_id>]``;
    the remaining lines are the hexlified ciphertext wrapped at 80 columns.
    No validation is done here; parse_vaulttext_envelope() wraps errors.
    """
    b_tmpdata = b_vaulttext_envelope.splitlines()
    b_tmpheader = b_tmpdata[0].strip().split(b';')

    b_version = b_tmpheader[1].strip()
    cipher_name = to_text(b_tmpheader[2].strip())
    vault_id = default_vault_id

    # Only attempt to find vault_id if the vault file is version 1.2 or newer
    # (a 1.2+ header carries a fourth, vault-id field)
    if len(b_tmpheader) >= 4:
        vault_id = to_text(b_tmpheader[3].strip())

    # re-join the wrapped ciphertext lines into a single byte string
    b_ciphertext = b''.join(b_tmpdata[1:])

    return b_ciphertext, b_version, cipher_name, vault_id


def parse_vaulttext_envelope(b_vaulttext_envelope, default_vault_id=None, filename=None):
    """Parse the vaulttext envelope

    When data is saved, it has a header prepended and is formatted into 80
    character lines.  This method extracts the information from the header
    and then removes the header and the inserted newlines.  The string returned
    is suitable for processing by the Cipher classes.

    :arg b_vaulttext_envelope: byte str containing the data from a save file
    :kwarg default_vault_id: The vault_id name to use if the vaulttext does not provide one.
    :kwarg filename: The filename that the data came from.  This is only
        used to make better error messages in case the data cannot be
        decrypted. This is optional.
    :returns: A tuple of byte str of the vaulttext suitable to pass to parse_vaulttext,
        a byte str of the vault format version,
        the name of the cipher used, and the vault_id.
    :raises: AnsibleVaultFormatError: if the vaulttext_envelope format is invalid
    """
    # used by decrypt
    default_vault_id = default_vault_id or C.DEFAULT_VAULT_IDENTITY

    try:
        return _parse_vaulttext_envelope(b_vaulttext_envelope, default_vault_id)
    except Exception as exc:
        msg = "Vault envelope format error"
        if filename:
            msg += ' in %s' % (filename)
        msg += ': %s' % exc
        raise AnsibleVaultFormatError(msg)


def format_vaulttext_envelope(b_ciphertext, cipher_name, version=None, vault_id=None):
    """ Add header and format to 80 columns

    :arg b_ciphertext: the encrypted and hexlified data as a byte string
    :arg cipher_name: unicode cipher name (for ex, u'AES256')
    :arg version: unicode vault version (for ex, '1.2'). Optional ('1.1' is default)
    :arg vault_id: unicode vault identifier. If provided, the version will be bumped to 1.2.
    :returns: a byte str that should be dumped into a file.  It's
        formatted to 80 char columns and has the header prepended
    """

    if not cipher_name:
        raise AnsibleError("the cipher must be set before adding a header")

    version = version or '1.1'

    # If we specify a vault_id, use format version 1.2. For no vault_id, stick to 1.1
    if vault_id and vault_id != u'default':
        version = '1.2'

    b_version = to_bytes(version, 'utf-8', errors='strict')
    b_vault_id = to_bytes(vault_id, 'utf-8', errors='strict')
    b_cipher_name = to_bytes(cipher_name, 'utf-8', errors='strict')

    header_parts = [b_HEADER,
                    b_version,
                    b_cipher_name]

    # only a 1.2 header carries the vault-id field
    if b_version == b'1.2' and b_vault_id:
        header_parts.append(b_vault_id)

    header = b';'.join(header_parts)

    # header line, then the ciphertext wrapped at 80 columns, then a trailing newline
    b_vaulttext = [header]
    b_vaulttext += [b_ciphertext[i:i + 80] for i in range(0, len(b_ciphertext), 80)]
    b_vaulttext += [b'']
    b_vaulttext = b'\n'.join(b_vaulttext)

    return b_vaulttext


def _unhexlify(b_data):
    """unhexlify, re-raising any failure as AnsibleVaultFormatError."""
    try:
        return unhexlify(b_data)
    except (BinasciiError, TypeError) as exc:
        raise AnsibleVaultFormatError('Vault format unhexlify error: %s' % exc)


def _parse_vaulttext(b_vaulttext):
    """Unhexlify the payload and split it into (ciphertext, salt, crypted_hmac)."""
    b_vaulttext = _unhexlify(b_vaulttext)
    # the decoded payload is three newline-separated fields: salt, hmac, ciphertext
    b_salt, b_crypted_hmac, b_ciphertext = b_vaulttext.split(b"\n", 2)
    b_salt = _unhexlify(b_salt)
    b_ciphertext = _unhexlify(b_ciphertext)

    return b_ciphertext, b_salt, b_crypted_hmac


def parse_vaulttext(b_vaulttext):
    """Parse the vaulttext

    :arg b_vaulttext: byte str containing the vaulttext (ciphertext, salt, crypted_hmac)
    :returns: A tuple of byte str of the ciphertext suitable for passing to a
        Cipher class's decrypt() function, a byte str of the salt,
        and a byte str of the crypted_hmac
    :raises: AnsibleVaultFormatError: if the vaulttext format is invalid
    """
    # SPLIT SALT, DIGEST, AND DATA
    try:
        return _parse_vaulttext(b_vaulttext)
    except AnsibleVaultFormatError:
        raise
    except Exception as exc:
        msg = "Vault vaulttext format error: %s" % exc
        raise AnsibleVaultFormatError(msg)


def verify_secret_is_not_empty(secret, msg=None):
    '''Check the secret against minimal requirements.

    Raises: AnsibleVaultPasswordError if the password does not meet requirements.

    Currently, the only requirement is that the password is not None or an empty string.
    '''
    msg = msg or 'Invalid vault password was provided'
    if not secret:
        raise AnsibleVaultPasswordError(msg)
class VaultSecret:
    '''Opaque/abstract objects for a single vault secret. ie, a password or a key.'''

    def __init__(self, _bytes=None):
        # FIXME: ? that seems wrong... Unset etc?
        self._bytes = _bytes

    @property
    def bytes(self):
        '''The secret as a bytestring.

        Sub classes that store text types will need to override to encode the text to bytes.
        '''
        return self._bytes

    def load(self):
        # base implementation: the secret was supplied directly, nothing to load
        return self._bytes


class PromptVaultSecret(VaultSecret):
    '''A vault secret gathered interactively from the user via display.prompt().'''

    default_prompt_formats = ["Vault password (%s): "]

    def __init__(self, _bytes=None, vault_id=None, prompt_formats=None):
        super(PromptVaultSecret, self).__init__(_bytes=_bytes)
        self.vault_id = vault_id

        # one prompt per format; multiple formats mean the password is asked
        # for more than once and the answers must match (confirmation)
        if prompt_formats is None:
            self.prompt_formats = self.default_prompt_formats
        else:
            self.prompt_formats = prompt_formats

    @property
    def bytes(self):
        return self._bytes

    def load(self):
        # prompting happens lazily, only when the secret is actually needed
        self._bytes = self.ask_vault_passwords()

    def ask_vault_passwords(self):
        """Prompt once per prompt format and return the confirmed password bytes.

        :returns: the password as bytes, or None if no prompts were configured
        :raises: AnsibleVaultError on EOF, AnsibleError if confirmations differ
        """
        b_vault_passwords = []

        for prompt_format in self.prompt_formats:
            prompt = prompt_format % {'vault_id': self.vault_id}
            try:
                vault_pass = display.prompt(prompt, private=True)
            except EOFError:
                raise AnsibleVaultError('EOFError (ctrl-d) on prompt for (%s)' % self.vault_id)

            verify_secret_is_not_empty(vault_pass)

            b_vault_pass = to_bytes(vault_pass, errors='strict', nonstring='simplerepr').strip()
            b_vault_passwords.append(b_vault_pass)

        # Make sure the passwords match by comparing them all to the first password
        for b_vault_password in b_vault_passwords:
            self.confirm(b_vault_passwords[0], b_vault_password)

        if b_vault_passwords:
            return b_vault_passwords[0]

        return None

    def confirm(self, b_vault_pass_1, b_vault_pass_2):
        # enforce no newline chars at the end of passwords

        if b_vault_pass_1 != b_vault_pass_2:
            # FIXME: more specific exception
            raise AnsibleError("Passwords do not match")


def script_is_client(filename):
    '''Determine if a vault secret script is a client script that can be given --vault-id args'''

    # if password script is 'something-client' or 'something-client.[sh|py|rb|etc]'
    # script_name can still have '.' or could be entire filename if there is no ext
    script_name, dummy = os.path.splitext(filename)

    # TODO: for now, this is entirely based on filename
    if script_name.endswith('-client'):
        return True

    return False


def get_file_vault_secret(filename=None, vault_id=None, encoding=None, loader=None):
    ''' Get secret from file content or execute file and get secret from stdout '''

    # we unfrack but not follow the full path/context to possible vault script
    # so when the script uses 'adjacent' file for configuration or similar
    # it still works (as inventory scripts often also do).
    # while files from --vault-password-file are already unfracked, other sources are not
    this_path = unfrackpath(filename, follow=False)
    if not os.path.exists(this_path):
        raise AnsibleError("The vault password file %s was not found" % this_path)

    # it is a script?
    if loader.is_executable(this_path):

        if script_is_client(filename):
            # this is special script type that handles vault ids
            display.vvvv(u'The vault password file %s is a client script.' % to_text(this_path))
            # TODO: pass vault_id_name to script via cli
            return ClientScriptVaultSecret(filename=this_path, vault_id=vault_id, encoding=encoding, loader=loader)

        # just a plain vault password script. No args, returns a byte array
        return ScriptVaultSecret(filename=this_path, encoding=encoding, loader=loader)

    # not executable: treat the file content itself as the secret
    return FileVaultSecret(filename=this_path, encoding=encoding, loader=loader)


# TODO: mv these classes to a separate file so we don't pollute vault with 'subprocess' etc
class FileVaultSecret(VaultSecret):
    '''A vault secret read from the contents of a file.'''

    def __init__(self, filename=None, encoding=None, loader=None):
        super(FileVaultSecret, self).__init__()
        self.filename = filename
        self.loader = loader

        self.encoding = encoding or 'utf8'

        # We could load from file here, but that is eventually a pain to test
        self._bytes = None
        self._text = None

    @property
    def bytes(self):
        # prefer raw bytes; fall back to encoding any text form of the secret
        if self._bytes:
            return self._bytes
        if self._text:
            return self._text.encode(self.encoding)
        return None

    def load(self):
        self._bytes = self._read_file(self.filename)

    def _read_file(self, filename):
        """
        Read a vault password from a file or if executable, execute the script and
        retrieve password from STDOUT
        """

        # TODO: replace with use of self.loader
        try:
            with open(filename, "rb") as f:
                vault_pass = f.read().strip()
        except (OSError, IOError) as e:
            raise AnsibleError("Could not read vault password file %s: %s" % (filename, e))

        # the password file itself may be vault-encrypted; decrypt it if so
        b_vault_data, dummy = self.loader._decrypt_if_vault_data(vault_pass, filename)

        vault_pass = b_vault_data.strip(b'\r\n')

        verify_secret_is_not_empty(vault_pass,
                                   msg='Invalid vault password was provided from file (%s)' % filename)

        return vault_pass

    def __repr__(self):
        if self.filename:
            return "%s(filename='%s')" % (self.__class__.__name__, self.filename)
        return "%s()" % (self.__class__.__name__)


class ScriptVaultSecret(FileVaultSecret):
    '''A vault secret produced by running an executable script and reading its stdout.'''

    def _read_file(self, filename):
        if not self.loader.is_executable(filename):
            raise AnsibleVaultError("The vault password script %s was not executable" % filename)

        command = self._build_command()

        stdout, stderr, p = self._run(command)

        self._check_results(stdout, stderr, p)

        vault_pass = stdout.strip(b'\r\n')

        empty_password_msg = 'Invalid vault password was provided from script (%s)' % filename
        verify_secret_is_not_empty(vault_pass, msg=empty_password_msg)

        return vault_pass

    def _run(self, command):
        try:
            # STDERR not captured to make it easier for users to prompt for input in their scripts
            p = subprocess.Popen(command, stdout=subprocess.PIPE)
        except OSError as e:
            msg_format = "Problem running vault password script %s (%s)." \
                " If this is not a script, remove the executable bit from the file."
            msg = msg_format % (self.filename, e)

            raise AnsibleError(msg)

        stdout, stderr = p.communicate()
        return stdout, stderr, p

    def _check_results(self, stdout, stderr, popen):
        if popen.returncode != 0:
            raise AnsibleError("Vault password script %s returned non-zero (%s): %s" %
                               (self.filename, popen.returncode, stderr))

    def _build_command(self):
        # plain scripts take no arguments; subclasses may extend the command
        return [self.filename]


class ClientScriptVaultSecret(ScriptVaultSecret):
    '''A vault secret script that understands a --vault-id argument (``*-client`` scripts).'''

    # conventional return code the client script uses for "no secret for this vault-id"
    VAULT_ID_UNKNOWN_RC = 2

    def __init__(self, filename=None, encoding=None, loader=None, vault_id=None):
        super(ClientScriptVaultSecret, self).__init__(filename=filename,
                                                      encoding=encoding,
                                                      loader=loader)
        self._vault_id = vault_id
        display.vvvv(u'Executing vault password client script: %s --vault-id %s' % (to_text(filename), to_text(vault_id)))

    def _run(self, command):
        try:
            # unlike plain scripts, stderr is captured here so it can be
            # included in the error messages below
            p = subprocess.Popen(command,
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
        except OSError as e:
            msg_format = "Problem running vault password client script %s (%s)." \
                " If this is not a script, remove the executable bit from the file."
            msg = msg_format % (self.filename, e)

            raise AnsibleError(msg)

        stdout, stderr = p.communicate()
        return stdout, stderr, p

    def _check_results(self, stdout, stderr, popen):
        if popen.returncode == self.VAULT_ID_UNKNOWN_RC:
            raise AnsibleError('Vault password client script %s did not find a secret for vault-id=%s: %s' %
                               (self.filename, self._vault_id, stderr))

        if popen.returncode != 0:
            raise AnsibleError("Vault password client script %s returned non-zero (%s) when getting secret for vault-id=%s: %s" %
                               (self.filename, popen.returncode, self._vault_id, stderr))

    def _build_command(self):
        command = [self.filename]
        if self._vault_id:
            command.extend(['--vault-id', self._vault_id])

        return command

    def __repr__(self):
        if self.filename:
            return "%s(filename='%s', vault_id='%s')" % \
                (self.__class__.__name__, self.filename, self._vault_id)
        return "%s()" % (self.__class__.__name__)


def match_secrets(secrets, target_vault_ids):
    '''Find all VaultSecret objects that are mapped to any of the target_vault_ids in secrets'''
    if not secrets:
        return []

    matches = [(vault_id, secret) for vault_id, secret in secrets if vault_id in target_vault_ids]
    return matches


def match_best_secret(secrets, target_vault_ids):
    '''Find the best secret from secrets that matches target_vault_ids

    Since secrets should be ordered so the early secrets are 'better' than later ones, this
    just finds all the matches, then returns the first secret'''
    matches = match_secrets(secrets, target_vault_ids)
    if matches:
        return matches[0]
    # raise exception?
    return None
def match_encrypt_vault_id_secret(secrets, encrypt_vault_id=None):
    """Return the (vault_id, secret) tuple matching the explicit --encrypt-vault-id.

    :raises: AnsibleError if encrypt_vault_id is None;
        AnsibleVaultError if no secret matches it (no fallback is attempted).
    """
    # See if the --encrypt-vault-id matches a vault-id
    display.vvvv(u'encrypt_vault_id=%s' % to_text(encrypt_vault_id))

    if encrypt_vault_id is None:
        raise AnsibleError('match_encrypt_vault_id_secret requires a non None encrypt_vault_id')

    encrypt_vault_id_matchers = [encrypt_vault_id]
    encrypt_secret = match_best_secret(secrets, encrypt_vault_id_matchers)

    # return the best match for --encrypt-vault-id
    if encrypt_secret:
        return encrypt_secret

    # If we specified a encrypt_vault_id and we couldn't find it, don't
    # fallback to using the first/best secret
    raise AnsibleVaultError('Did not find a match for --encrypt-vault-id=%s in the known vault-ids %s' % (encrypt_vault_id,
                                                                                                          [_v for _v, _vs in secrets]))


def match_encrypt_secret(secrets, encrypt_vault_id=None):
    '''Find the best/first/only secret in secrets to use for encrypting'''

    display.vvvv(u'encrypt_vault_id=%s' % to_text(encrypt_vault_id))
    # See if the --encrypt-vault-id matches a vault-id
    if encrypt_vault_id:
        return match_encrypt_vault_id_secret(secrets,
                                             encrypt_vault_id=encrypt_vault_id)

    # Find the best/first secret from secrets since we didn't specify otherwise
    # ie, consider all of the available secrets as matches
    _vault_id_matchers = [_vault_id for _vault_id, dummy in secrets]
    best_secret = match_best_secret(secrets, _vault_id_matchers)

    # can be empty list sans any tuple
    return best_secret


class VaultLib:
    """Encrypt/decrypt vault payloads using a list of (vault_id, VaultSecret) pairs."""

    def __init__(self, secrets=None):
        self.secrets = secrets or []
        self.cipher_name = None
        self.b_version = b'1.2'

    @staticmethod
    def is_encrypted(vaulttext):
        # thin wrapper kept for API compatibility; see module-level is_encrypted()
        return is_encrypted(vaulttext)

    def encrypt(self, plaintext, secret=None, vault_id=None, salt=None):
        """Vault encrypt a piece of data.

        :arg plaintext: a text or byte string to encrypt.
        :returns: a utf-8 encoded byte str of encrypted data.  The string
            contains a header identifying this as vault encrypted data and
            formatted to newline terminated lines of 80 characters. This is
            suitable for dumping as is to a vault file.

        If the string passed in is a text string, it will be encoded to UTF-8
        before encryption.
        """

        if secret is None:
            # no explicit secret: pick the best one from the configured secrets
            if self.secrets:
                dummy, secret = match_encrypt_secret(self.secrets)
            else:
                raise AnsibleVaultError("A vault password must be specified to encrypt data")

        b_plaintext = to_bytes(plaintext, errors='surrogate_or_strict')

        if is_encrypted(b_plaintext):
            raise AnsibleError("input is already encrypted")

        # only ciphers in the write whitelist may be used for new data
        if not self.cipher_name or self.cipher_name not in CIPHER_WRITE_WHITELIST:
            self.cipher_name = u"AES256"

        try:
            this_cipher = CIPHER_MAPPING[self.cipher_name]()
        except KeyError:
            raise AnsibleError(u"{0} cipher could not be found".format(self.cipher_name))

        # encrypt data
        if vault_id:
            display.vvvvv(u'Encrypting with vault_id "%s" and vault secret %s' % (to_text(vault_id), to_text(secret)))
        else:
            display.vvvvv(u'Encrypting without a vault_id using vault secret %s' % to_text(secret))

        b_ciphertext = this_cipher.encrypt(b_plaintext, secret, salt)

        # format the data for output to the file
        b_vaulttext = format_vaulttext_envelope(b_ciphertext,
                                                self.cipher_name,
                                                vault_id=vault_id)
        return b_vaulttext

    def decrypt(self, vaulttext, filename=None, obj=None):
        '''Decrypt a piece of vault encrypted data.

        :arg vaulttext: a string to decrypt.  Since vault encrypted data is an
            ascii text format this can be either a byte str or unicode string.
        :kwarg filename: a filename that the data came from.  This is only
            used to make better error messages in case the data cannot be
            decrypted.
        :returns: a byte string containing the decrypted data
        '''
        plaintext, vault_id, vault_secret = self.decrypt_and_get_vault_id(vaulttext, filename=filename, obj=obj)
        return plaintext

    def decrypt_and_get_vault_id(self, vaulttext, filename=None, obj=None):
        """Decrypt a piece of vault encrypted data.

        :arg vaulttext: a string to decrypt.  Since vault encrypted data is an
            ascii text format this can be either a byte str or unicode string.
        :kwarg filename: a filename that the data came from.  This is only
            used to make better error messages in case the data cannot be
            decrypted.
        :returns: a tuple of the decrypted data as a byte string, the vault-id
            that was used, and the vault secret that was used
        """
        b_vaulttext = to_bytes(vaulttext, errors='strict', encoding='utf-8')

        if self.secrets is None:
            raise AnsibleVaultError("A vault password must be specified to decrypt data")

        if not is_encrypted(b_vaulttext):
            msg = "input is not vault encrypted data. "
            if filename:
                msg += "%s is not a vault encrypted file" % to_native(filename)
            raise AnsibleError(msg)

        b_vaulttext, dummy, cipher_name, vault_id = parse_vaulttext_envelope(b_vaulttext, filename=filename)

        # create the cipher object, note that the cipher used for decrypt can
        # be different than the cipher used for encrypt
        if cipher_name in CIPHER_WHITELIST:
            this_cipher = CIPHER_MAPPING[cipher_name]()
        else:
            raise AnsibleError("{0} cipher could not be found".format(cipher_name))

        b_plaintext = None

        if not self.secrets:
            raise AnsibleVaultError('Attempting to decrypt but no vault secrets found')

        # WARNING: Currently, the vault id is not required to match the vault id in the vault blob to
        #          decrypt a vault properly. The vault id in the vault blob is not part of the encrypted
        #          or signed vault payload. There is no cryptographic checking/verification/validation of the
        #          vault blobs vault id. It can be tampered with and changed. The vault id is just a nick
        #          name to use to pick the best secret and provide some ux/ui info.

        # iterate over all the applicable secrets (all of them by default) until one works...
        # if we specify a vault_id, only the corresponding vault secret is checked and
        # we check it first.

        vault_id_matchers = []
        vault_id_used = None
        vault_secret_used = None

        if vault_id:
            display.vvvvv(u'Found a vault_id (%s) in the vaulttext' % to_text(vault_id))
            vault_id_matchers.append(vault_id)
            _matches = match_secrets(self.secrets, vault_id_matchers)
            if _matches:
                display.vvvvv(u'We have a secret associated with vault id (%s), will try to use to decrypt %s' % (to_text(vault_id), to_text(filename)))
            else:
                display.vvvvv(u'Found a vault_id (%s) in the vault text, but we do not have a associated secret (--vault-id)' % to_text(vault_id))

        # Not adding the other secrets to vault_secret_ids enforces a match between the vault_id from the vault_text and
        # the known vault secrets.
        if not C.DEFAULT_VAULT_ID_MATCH:
            # Add all of the known vault_ids as candidates for decrypting a vault.
            vault_id_matchers.extend([_vault_id for _vault_id, _dummy in self.secrets if _vault_id != vault_id])

        matched_secrets = match_secrets(self.secrets, vault_id_matchers)

        # try each candidate secret in order until one decrypts the payload
        for vault_secret_id, vault_secret in matched_secrets:
            display.vvvvv(u'Trying to use vault secret=(%s) id=%s to decrypt %s' % (to_text(vault_secret), to_text(vault_secret_id), to_text(filename)))

            try:
                display.vvvv(u'Trying secret %s for vault_id=%s' % (to_text(vault_secret), to_text(vault_secret_id)))
                b_plaintext = this_cipher.decrypt(b_vaulttext, vault_secret)
                if b_plaintext is not None:
                    vault_id_used = vault_secret_id
                    vault_secret_used = vault_secret
                    file_slug = ''
                    if filename:
                        file_slug = ' of "%s"' % filename
                    display.vvvvv(
                        u'Decrypt%s successful with secret=%s and vault_id=%s' % (to_text(file_slug), to_text(vault_secret), to_text(vault_secret_id))
                    )
                    break
            except AnsibleVaultFormatError as exc:
                # a format error is fatal: no other secret could parse this blob either
                exc.obj = obj
                msg = u"There was a vault format error"
                if filename:
                    msg += u' in %s' % (to_text(filename))
                msg += u': %s' % to_text(exc)
                display.warning(msg, formatted=True)
                raise
            except AnsibleError as e:
                # wrong secret (HMAC mismatch etc); try the next candidate
                display.vvvv(u'Tried to use the vault secret (%s) to decrypt (%s) but it failed. Error: %s' %
                             (to_text(vault_secret_id), to_text(filename), e))
                continue
        else:
            # for/else: the loop never hit 'break', i.e. no candidate decrypted it
            msg = "Decryption failed (no vault secrets were found that could decrypt)"
            if filename:
                msg += " on %s" % to_native(filename)
            raise AnsibleVaultError(msg)

        if b_plaintext is None:
            msg = "Decryption failed"
            if filename:
                msg += " on %s" % to_native(filename)
            raise AnsibleError(msg)

        return b_plaintext, vault_id_used, vault_secret_used
class VaultEditor:
    """File-level vault operations: create/edit/encrypt/decrypt/rekey vault files.

    Wraps a VaultLib and adds the filesystem handling (temp files, secure
    shredding of plaintext, permission preservation).
    """

    def __init__(self, vault=None):
        # TODO: it may be more useful to just make VaultSecrets and index of VaultLib objects...
        self.vault = vault or VaultLib()

    # TODO: mv shred file stuff to its own class
    def _shred_file_custom(self, tmp_path):
        """Destroy a file, when shred (core-utils) is not available

        Unix `shred' destroys files "so that they can be recovered only with great difficulty with
        specialised hardware, if at all". It is based on the method from the paper
        "Secure Deletion of Data from Magnetic and Solid-State Memory",
        Proceedings of the Sixth USENIX Security Symposium (San Jose, California, July 22-25, 1996).

        We do not go to that length to re-implement shred in Python; instead, overwriting with a block
        of random data should suffice.

        See https://github.com/ansible/ansible/pull/13700 .
        """

        file_len = os.path.getsize(tmp_path)

        if file_len > 0:  # avoid work when file was empty
            max_chunk_len = min(1024 * 1024 * 2, file_len)

            passes = 3
            with open(tmp_path, "wb") as fh:
                for _ in range(passes):
                    fh.seek(0, 0)
                    # get a random chunk of data, each pass with other length
                    chunk_len = random.randint(max_chunk_len // 2, max_chunk_len)
                    data = os.urandom(chunk_len)

                    # repeat the chunk to cover the whole file, then the remainder
                    for _ in range(0, file_len // chunk_len):
                        fh.write(data)
                    fh.write(data[:file_len % chunk_len])

                    # FIXME remove this assert once we have unittests to check its accuracy
                    if fh.tell() != file_len:
                        raise AnsibleAssertionError()

                    os.fsync(fh)

    def _shred_file(self, tmp_path):
        """Securely destroy a decrypted file

        Note standard limitations of GNU shred apply (For flash, overwriting would have no effect
        due to wear leveling; for other storage systems, the async kernel->filesystem->disk calls never
        guarantee data hits the disk; etc). Furthermore, if your tmp dirs is on tmpfs (ramdisks),
        it is a non-issue.

        Nevertheless, some form of overwriting the data (instead of just removing the fs index entry) is
        a good idea. If shred is not available (e.g. on windows, or no core-utils installed), fall back on
        a custom shredding method.
        """

        if not os.path.isfile(tmp_path):
            # file is already gone
            return

        try:
            r = subprocess.call(['shred', tmp_path])
        except (OSError, ValueError):
            # shred is not available on this system, or some other error occurred.
            # ValueError caught because macOS El Capitan is raising an
            # exception big enough to hit a limit in python2-2.7.11 and below.
            # Symptom is ValueError: insecure pickle when shred is not
            # installed there.
            r = 1

        if r != 0:
            # we could not successfully execute unix shred; therefore, do custom shred.
            self._shred_file_custom(tmp_path)

        os.remove(tmp_path)

    def _edit_file_helper(self, filename, secret, existing_data=None, force_save=False, vault_id=None):
        """Decrypt to a temp file, let the user edit it, re-encrypt, then shred the temp file."""

        # Create a tempfile
        root, ext = os.path.splitext(os.path.realpath(filename))
        fd, tmp_path = tempfile.mkstemp(suffix=ext, dir=C.DEFAULT_LOCAL_TMP)

        cmd = self._editor_shell_command(tmp_path)
        try:
            if existing_data:
                self.write_data(existing_data, fd, shred=False)
        except Exception:
            # if an error happens, destroy the decrypted file
            self._shred_file(tmp_path)
            raise
        finally:
            os.close(fd)

        try:
            # drop the user into an editor on the tmp file
            subprocess.call(cmd)
        except Exception as e:
            # if an error happens, destroy the decrypted file
            self._shred_file(tmp_path)
            raise AnsibleError('Unable to execute the command "%s": %s' % (' '.join(cmd), to_native(e)))

        b_tmpdata = self.read_data(tmp_path)

        # Do nothing if the content has not changed
        if force_save or existing_data != b_tmpdata:

            # encrypt new data and write out to tmp
            # An existing vaultfile will always be UTF-8,
            # so decode to unicode here
            b_ciphertext = self.vault.encrypt(b_tmpdata, secret, vault_id=vault_id)
            self.write_data(b_ciphertext, tmp_path)

            # shuffle tmp file into place
            self.shuffle_files(tmp_path, filename)
            display.vvvvv(u'Saved edited file "%s" encrypted using %s and vault id "%s"' % (to_text(filename), to_text(secret), to_text(vault_id)))

        # always shred temp, jic
        self._shred_file(tmp_path)

    def _real_path(self, filename):
        # '-' is special to VaultEditor (stdin/stdout), don't expand it.
        if filename == '-':
            return filename

        real_path = os.path.realpath(filename)
        return real_path

    def encrypt_bytes(self, b_plaintext, secret, vault_id=None):
        """Encrypt the given plaintext bytes and return the vaulttext envelope."""

        b_ciphertext = self.vault.encrypt(b_plaintext, secret, vault_id=vault_id)

        return b_ciphertext

    def encrypt_file(self, filename, secret, vault_id=None, output_file=None):
        """Encrypt a file in place (or to output_file if given)."""

        # A file to be encrypted into a vaultfile could be any encoding
        # so treat the contents as a byte string.

        # follow the symlink
        filename = self._real_path(filename)

        b_plaintext = self.read_data(filename)
        b_ciphertext = self.vault.encrypt(b_plaintext, secret, vault_id=vault_id)
        self.write_data(b_ciphertext, output_file or filename)

    def decrypt_file(self, filename, output_file=None):
        """Decrypt a vault file in place (or to output_file if given)."""

        # follow the symlink
        filename = self._real_path(filename)

        ciphertext = self.read_data(filename)

        try:
            plaintext = self.vault.decrypt(ciphertext, filename=filename)
        except AnsibleError as e:
            raise AnsibleError("%s for %s" % (to_native(e), to_native(filename)))
        # plaintext output must not be shredded-and-recreated; just overwrite
        self.write_data(plaintext, output_file or filename, shred=False)

    def create_file(self, filename, secret, vault_id=None):
        """ create a new encrypted file """

        dirname = os.path.dirname(filename)
        if dirname and not os.path.exists(dirname):
            display.warning(u"%s does not exist, creating..." % to_text(dirname))
            makedirs_safe(dirname)

        # FIXME: If we can raise an error here, we can probably just make it
        # behave like edit instead.
        if os.path.isfile(filename):
            raise AnsibleError("%s exists, please use 'edit' instead" % filename)

        self._edit_file_helper(filename, secret, vault_id=vault_id)

    def edit_file(self, filename):
        """Decrypt a vault file, open it in the user's editor, and re-encrypt the result."""
        vault_id_used = None
        vault_secret_used = None
        # follow the symlink
        filename = self._real_path(filename)

        b_vaulttext = self.read_data(filename)

        # vault or yaml files are always utf8
        vaulttext = to_text(b_vaulttext)

        try:
            # vaulttext gets converted back to bytes, but alas
            # TODO: return the vault_id that worked?
            plaintext, vault_id_used, vault_secret_used = self.vault.decrypt_and_get_vault_id(vaulttext)
        except AnsibleError as e:
            raise AnsibleError("%s for %s" % (to_native(e), to_native(filename)))

        # Figure out the vault id from the file, to select the right secret to re-encrypt it
        # (duplicates parts of decrypt, but alas...)
        dummy, dummy, cipher_name, vault_id = parse_vaulttext_envelope(b_vaulttext, filename=filename)

        # vault id here may not be the vault id actually used for decrypting
        # as when the edited file has no vault-id but is decrypted by non-default id in secrets
        # (vault_id=default, while a different vault-id decrypted)

        # we want to get rid of files encrypted with the AES cipher
        force_save = (cipher_name not in CIPHER_WRITE_WHITELIST)

        # Keep the same vault-id (and version) as in the header
        self._edit_file_helper(filename, vault_secret_used, existing_data=plaintext, force_save=force_save, vault_id=vault_id)

    def plaintext(self, filename):
        """Return the decrypted contents of a vault file without modifying it."""

        b_vaulttext = self.read_data(filename)
        vaulttext = to_text(b_vaulttext)

        try:
            plaintext = self.vault.decrypt(vaulttext, filename=filename)
            return plaintext
        except AnsibleError as e:
            raise AnsibleVaultError("%s for %s" % (to_native(e), to_native(filename)))

    # FIXME/TODO: make this use VaultSecret
    def rekey_file(self, filename, new_vault_secret, new_vault_id=None):
        """Decrypt a vault file and re-encrypt it with a new secret (and optionally a new vault-id)."""

        # follow the symlink
        filename = self._real_path(filename)

        # remember ownership/permissions so they can be restored after rewrite
        prev = os.stat(filename)
        b_vaulttext = self.read_data(filename)
        vaulttext = to_text(b_vaulttext)

        display.vvvvv(u'Rekeying file "%s" to with new vault-id "%s" and vault secret %s' %
                      (to_text(filename), to_text(new_vault_id), to_text(new_vault_secret)))
        try:
            plaintext, vault_id_used, _dummy = self.vault.decrypt_and_get_vault_id(vaulttext)
        except AnsibleError as e:
            raise AnsibleError("%s for %s" % (to_native(e), to_native(filename)))

        # This is more or less an assert, see #18247
        if new_vault_secret is None:
            raise AnsibleError('The value for the new_password to rekey %s with is not valid' % filename)

        # FIXME: VaultContext...? could rekey to a different vault_id in the same VaultSecrets

        # Need a new VaultLib because the new vault data can be a different
        # vault lib format or cipher (for ex, when we migrate 1.0 style vault data to
        # 1.1 style data we change the version and the cipher). This is where a VaultContext might help

        # the new vault will only be used for encrypting, so it doesn't need the vault secrets
        # (we will pass one in directly to encrypt)
        new_vault = VaultLib(secrets={})
        b_new_vaulttext = new_vault.encrypt(plaintext, new_vault_secret, vault_id=new_vault_id)

        self.write_data(b_new_vaulttext, filename)

        # preserve permissions
        os.chmod(filename, prev.st_mode)
        os.chown(filename, prev.st_uid, prev.st_gid)

        display.vvvvv(u'Rekeyed file "%s" (decrypted with vault id "%s") was encrypted with new vault-id "%s" and vault secret %s' %
                      (to_text(filename), to_text(vault_id_used), to_text(new_vault_id), to_text(new_vault_secret)))

    def read_data(self, filename):
        """Read and return the raw bytes of ``filename``; '-' reads from stdin."""

        try:
            if filename == '-':
                data = sys.stdin.buffer.read()
            else:
                with open(filename, "rb") as fh:
                    data = fh.read()
        except Exception as e:
            msg = to_native(e)
            if not msg:
                msg = repr(e)
            raise AnsibleError('Unable to read source file (%s): %s' % (to_native(filename), msg))

        return data

    def write_data(self, data, thefile, shred=True, mode=0o600):
        # TODO: add docstrings for arg types since this code is picky about that
        """Write the data bytes to given path

        This is used to write a byte string to a file or stdout. It is used for
        writing the results of vault encryption or decryption. It is used for
        saving the ciphertext after encryption and it is also used for saving the
        plaintext after decrypting a vault. The type of the 'data' arg should be bytes,
        since in the plaintext case, the original contents can be of any text encoding
        or arbitrary binary data.

        When used to write the result of vault encryption, the val of the 'data' arg
        should be a utf-8 encoded byte string and not a text type.

        When used to write the result of vault decryption, the val of the 'data' arg
        should be a byte string and not a text type.

        :arg data: the byte string (bytes) data
        :arg thefile: file descriptor or filename to save 'data' to.
        :arg shred: if shred==True, make sure that the original data is first shredded so that it cannot be recovered.
        :returns: None
        """
        # FIXME: do we need this now? data_bytes should always be a utf-8 byte string
        b_file_data = to_bytes(data, errors='strict')

        # check if we have a file descriptor instead of a path
        is_fd = False
        try:
            is_fd = (isinstance(thefile, int) and fcntl.fcntl(thefile, fcntl.F_GETFD) != -1)
        except Exception:
            pass

        if is_fd:
            # if passed descriptor, use that to ensure secure access, otherwise it is a string.
            # assumes the fd is securely opened by caller (mkstemp)
            os.ftruncate(thefile, 0)
            os.write(thefile, b_file_data)
        elif thefile == '-':
            # get a ref to either sys.stdout.buffer for py3 or plain old sys.stdout for py2
            # We need sys.stdout.buffer on py3 so we can write bytes to it since the plaintext
            # of the vaulted object could be anything/binary/etc
            output = getattr(sys.stdout, 'buffer', sys.stdout)
            output.write(b_file_data)
        else:
            # file names are insecure and prone to race conditions, so remove and create securely
            if os.path.isfile(thefile):
                if shred:
                    self._shred_file(thefile)
                else:
                    os.remove(thefile)

            # when setting new umask, we get previous as return
            current_umask = os.umask(0o077)
            try:
                try:
                    # create file with secure permissions
                    fd = os.open(thefile, os.O_CREAT | os.O_EXCL | os.O_RDWR | os.O_TRUNC, mode)
                except OSError as ose:
                    # Want to catch FileExistsError, which doesn't exist in Python 2, so catch OSError
                    # and compare the error number to get equivalent behavior in Python 2/3
                    if ose.errno == errno.EEXIST:
                        raise AnsibleError('Vault file got recreated while we were operating on it: %s' % to_native(ose))

                    raise AnsibleError('Problem creating temporary vault file: %s' % to_native(ose))

                try:
                    # now write to the file and ensure ours is only data in it
                    os.ftruncate(fd, 0)
                    os.write(fd, b_file_data)
                except OSError as e:
                    raise AnsibleError('Unable to write to temporary vault file: %s' % to_native(e))
                finally:
                    # Make sure the file descriptor is always closed and reset umask
                    os.close(fd)
            finally:
                os.umask(current_umask)
+ os.chmod(dest, prev.st_mode) + os.chown(dest, prev.st_uid, prev.st_gid) + + def _editor_shell_command(self, filename): + env_editor = os.environ.get('EDITOR', 'vi') + editor = shlex.split(env_editor) + editor.append(filename) + + return editor + + +######################################## +# CIPHERS # +######################################## + +class VaultAES256: + + """ + Vault implementation using AES-CTR with an HMAC-SHA256 authentication code. + Keys are derived using PBKDF2 + """ + + # http://www.daemonology.net/blog/2009-06-11-cryptographic-right-answers.html + + # Note: strings in this class should be byte strings by default. + + def __init__(self): + if not HAS_CRYPTOGRAPHY: + raise AnsibleError(NEED_CRYPTO_LIBRARY) + + @staticmethod + def _create_key_cryptography(b_password, b_salt, key_length, iv_length): + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=2 * key_length + iv_length, + salt=b_salt, + iterations=10000, + backend=CRYPTOGRAPHY_BACKEND) + b_derivedkey = kdf.derive(b_password) + + return b_derivedkey + + @classmethod + def _gen_key_initctr(cls, b_password, b_salt): + # 16 for AES 128, 32 for AES256 + key_length = 32 + + if HAS_CRYPTOGRAPHY: + # AES is a 128-bit block cipher, so IVs and counter nonces are 16 bytes + iv_length = algorithms.AES.block_size // 8 + + b_derivedkey = cls._create_key_cryptography(b_password, b_salt, key_length, iv_length) + b_iv = b_derivedkey[(key_length * 2):(key_length * 2) + iv_length] + else: + raise AnsibleError(NEED_CRYPTO_LIBRARY + '(Detected in initctr)') + + b_key1 = b_derivedkey[:key_length] + b_key2 = b_derivedkey[key_length:(key_length * 2)] + + return b_key1, b_key2, b_iv + + @staticmethod + def _encrypt_cryptography(b_plaintext, b_key1, b_key2, b_iv): + cipher = C_Cipher(algorithms.AES(b_key1), modes.CTR(b_iv), CRYPTOGRAPHY_BACKEND) + encryptor = cipher.encryptor() + padder = padding.PKCS7(algorithms.AES.block_size).padder() + b_ciphertext = encryptor.update(padder.update(b_plaintext) + 
padder.finalize()) + b_ciphertext += encryptor.finalize() + + # COMBINE SALT, DIGEST AND DATA + hmac = HMAC(b_key2, hashes.SHA256(), CRYPTOGRAPHY_BACKEND) + hmac.update(b_ciphertext) + b_hmac = hmac.finalize() + + return to_bytes(hexlify(b_hmac), errors='surrogate_or_strict'), hexlify(b_ciphertext) + + @classmethod + def encrypt(cls, b_plaintext, secret, salt=None): + + if secret is None: + raise AnsibleVaultError('The secret passed to encrypt() was None') + + if salt is None: + b_salt = os.urandom(32) + elif not salt: + raise AnsibleVaultError('Empty or invalid salt passed to encrypt()') + else: + b_salt = to_bytes(salt) + + b_password = secret.bytes + b_key1, b_key2, b_iv = cls._gen_key_initctr(b_password, b_salt) + + if HAS_CRYPTOGRAPHY: + b_hmac, b_ciphertext = cls._encrypt_cryptography(b_plaintext, b_key1, b_key2, b_iv) + else: + raise AnsibleError(NEED_CRYPTO_LIBRARY + '(Detected in encrypt)') + + b_vaulttext = b'\n'.join([hexlify(b_salt), b_hmac, b_ciphertext]) + # Unnecessary but getting rid of it is a backwards incompatible vault + # format change + b_vaulttext = hexlify(b_vaulttext) + return b_vaulttext + + @classmethod + def _decrypt_cryptography(cls, b_ciphertext, b_crypted_hmac, b_key1, b_key2, b_iv): + # b_key1, b_key2, b_iv = self._gen_key_initctr(b_password, b_salt) + # EXIT EARLY IF DIGEST DOESN'T MATCH + hmac = HMAC(b_key2, hashes.SHA256(), CRYPTOGRAPHY_BACKEND) + hmac.update(b_ciphertext) + try: + hmac.verify(_unhexlify(b_crypted_hmac)) + except InvalidSignature as e: + raise AnsibleVaultError('HMAC verification failed: %s' % e) + + cipher = C_Cipher(algorithms.AES(b_key1), modes.CTR(b_iv), CRYPTOGRAPHY_BACKEND) + decryptor = cipher.decryptor() + unpadder = padding.PKCS7(128).unpadder() + b_plaintext = unpadder.update( + decryptor.update(b_ciphertext) + decryptor.finalize() + ) + unpadder.finalize() + + return b_plaintext + + @staticmethod + def _is_equal(b_a, b_b): + """ + Comparing 2 byte arrays in constant time to avoid timing attacks. 
+ + It would be nice if there were a library for this but hey. + """ + if not (isinstance(b_a, binary_type) and isinstance(b_b, binary_type)): + raise TypeError('_is_equal can only be used to compare two byte strings') + + # http://codahale.com/a-lesson-in-timing-attacks/ + if len(b_a) != len(b_b): + return False + + result = 0 + for b_x, b_y in zip(b_a, b_b): + result |= b_x ^ b_y + return result == 0 + + @classmethod + def decrypt(cls, b_vaulttext, secret): + + b_ciphertext, b_salt, b_crypted_hmac = parse_vaulttext(b_vaulttext) + + # TODO: would be nice if a VaultSecret could be passed directly to _decrypt_* + # (move _gen_key_initctr() to a AES256 VaultSecret or VaultContext impl?) + # though, likely needs to be python cryptography specific impl that basically + # creates a Cipher() with b_key1, a Mode.CTR() with b_iv, and a HMAC() with sign key b_key2 + b_password = secret.bytes + + b_key1, b_key2, b_iv = cls._gen_key_initctr(b_password, b_salt) + + if HAS_CRYPTOGRAPHY: + b_plaintext = cls._decrypt_cryptography(b_ciphertext, b_crypted_hmac, b_key1, b_key2, b_iv) + else: + raise AnsibleError(NEED_CRYPTO_LIBRARY + '(Detected in decrypt)') + + return b_plaintext + + +# Keys could be made bytes later if the code that gets the data is more +# naturally byte-oriented +CIPHER_MAPPING = { + u'AES256': VaultAES256, +} diff --git a/lib/ansible/parsing/yaml/__init__.py b/lib/ansible/parsing/yaml/__init__.py new file mode 100644 index 0000000..ae8ccff --- /dev/null +++ b/lib/ansible/parsing/yaml/__init__.py @@ -0,0 +1,20 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type diff --git a/lib/ansible/parsing/yaml/constructor.py b/lib/ansible/parsing/yaml/constructor.py new file mode 100644 index 0000000..4b79578 --- /dev/null +++ b/lib/ansible/parsing/yaml/constructor.py @@ -0,0 +1,178 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. 
+ +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from yaml.constructor import SafeConstructor, ConstructorError +from yaml.nodes import MappingNode + +from ansible import constants as C +from ansible.module_utils._text import to_bytes, to_native +from ansible.parsing.yaml.objects import AnsibleMapping, AnsibleSequence, AnsibleUnicode, AnsibleVaultEncryptedUnicode +from ansible.parsing.vault import VaultLib +from ansible.utils.display import Display +from ansible.utils.unsafe_proxy import wrap_var + +display = Display() + + +class AnsibleConstructor(SafeConstructor): + def __init__(self, file_name=None, vault_secrets=None): + self._ansible_file_name = file_name + super(AnsibleConstructor, self).__init__() + self._vaults = {} + self.vault_secrets = vault_secrets or [] + self._vaults['default'] = VaultLib(secrets=self.vault_secrets) + + def construct_yaml_map(self, node): + data = AnsibleMapping() + yield data + value = self.construct_mapping(node) + data.update(value) + data.ansible_pos = self._node_position_info(node) + + def construct_mapping(self, node, deep=False): + # Most of this is from yaml.constructor.SafeConstructor. 
We replicate + # it here so that we can warn users when they have duplicate dict keys + # (pyyaml silently allows overwriting keys) + if not isinstance(node, MappingNode): + raise ConstructorError(None, None, + "expected a mapping node, but found %s" % node.id, + node.start_mark) + self.flatten_mapping(node) + mapping = AnsibleMapping() + + # Add our extra information to the returned value + mapping.ansible_pos = self._node_position_info(node) + + for key_node, value_node in node.value: + key = self.construct_object(key_node, deep=deep) + try: + hash(key) + except TypeError as exc: + raise ConstructorError("while constructing a mapping", node.start_mark, + "found unacceptable key (%s)" % exc, key_node.start_mark) + + if key in mapping: + msg = (u'While constructing a mapping from {1}, line {2}, column {3}, found a duplicate dict key ({0}).' + u' Using last defined value only.'.format(key, *mapping.ansible_pos)) + if C.DUPLICATE_YAML_DICT_KEY == 'warn': + display.warning(msg) + elif C.DUPLICATE_YAML_DICT_KEY == 'error': + raise ConstructorError(context=None, context_mark=None, + problem=to_native(msg), + problem_mark=node.start_mark, + note=None) + else: + # when 'ignore' + display.debug(msg) + + value = self.construct_object(value_node, deep=deep) + mapping[key] = value + + return mapping + + def construct_yaml_str(self, node): + # Override the default string handling function + # to always return unicode objects + value = self.construct_scalar(node) + ret = AnsibleUnicode(value) + + ret.ansible_pos = self._node_position_info(node) + + return ret + + def construct_vault_encrypted_unicode(self, node): + value = self.construct_scalar(node) + b_ciphertext_data = to_bytes(value) + # could pass in a key id here to choose the vault to associate with + # TODO/FIXME: plugin vault selector + vault = self._vaults['default'] + if vault.secrets is None: + raise ConstructorError(context=None, context_mark=None, + problem="found !vault but no vault password provided", + 
problem_mark=node.start_mark, + note=None) + ret = AnsibleVaultEncryptedUnicode(b_ciphertext_data) + ret.vault = vault + ret.ansible_pos = self._node_position_info(node) + return ret + + def construct_yaml_seq(self, node): + data = AnsibleSequence() + yield data + data.extend(self.construct_sequence(node)) + data.ansible_pos = self._node_position_info(node) + + def construct_yaml_unsafe(self, node): + try: + constructor = getattr(node, 'id', 'object') + if constructor is not None: + constructor = getattr(self, 'construct_%s' % constructor) + except AttributeError: + constructor = self.construct_object + + value = constructor(node) + + return wrap_var(value) + + def _node_position_info(self, node): + # the line number where the previous token has ended (plus empty lines) + # Add one so that the first line is line 1 rather than line 0 + column = node.start_mark.column + 1 + line = node.start_mark.line + 1 + + # in some cases, we may have pre-read the data and then + # passed it to the load() call for YAML, in which case we + # want to override the default datasource (which would be + # '<string>') to the actual filename we read in + datasource = self._ansible_file_name or node.start_mark.name + + return (datasource, line, column) + + +AnsibleConstructor.add_constructor( + u'tag:yaml.org,2002:map', + AnsibleConstructor.construct_yaml_map) + +AnsibleConstructor.add_constructor( + u'tag:yaml.org,2002:python/dict', + AnsibleConstructor.construct_yaml_map) + +AnsibleConstructor.add_constructor( + u'tag:yaml.org,2002:str', + AnsibleConstructor.construct_yaml_str) + +AnsibleConstructor.add_constructor( + u'tag:yaml.org,2002:python/unicode', + AnsibleConstructor.construct_yaml_str) + +AnsibleConstructor.add_constructor( + u'tag:yaml.org,2002:seq', + AnsibleConstructor.construct_yaml_seq) + +AnsibleConstructor.add_constructor( + u'!unsafe', + AnsibleConstructor.construct_yaml_unsafe) + +AnsibleConstructor.add_constructor( + u'!vault', + 
AnsibleConstructor.construct_vault_encrypted_unicode) + +AnsibleConstructor.add_constructor(u'!vault-encrypted', AnsibleConstructor.construct_vault_encrypted_unicode) diff --git a/lib/ansible/parsing/yaml/dumper.py b/lib/ansible/parsing/yaml/dumper.py new file mode 100644 index 0000000..8701bb8 --- /dev/null +++ b/lib/ansible/parsing/yaml/dumper.py @@ -0,0 +1,122 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import yaml + +from ansible.module_utils.six import text_type, binary_type +from ansible.module_utils.common.yaml import SafeDumper +from ansible.parsing.yaml.objects import AnsibleUnicode, AnsibleSequence, AnsibleMapping, AnsibleVaultEncryptedUnicode +from ansible.utils.unsafe_proxy import AnsibleUnsafeText, AnsibleUnsafeBytes, NativeJinjaUnsafeText, NativeJinjaText +from ansible.template import AnsibleUndefined +from ansible.vars.hostvars import HostVars, HostVarsVars +from ansible.vars.manager import VarsWithSources + + +class AnsibleDumper(SafeDumper): + ''' + A simple stub class that allows us to add representers + for our overridden object types. 
+ ''' + + +def represent_hostvars(self, data): + return self.represent_dict(dict(data)) + + +# Note: only want to represent the encrypted data +def represent_vault_encrypted_unicode(self, data): + return self.represent_scalar(u'!vault', data._ciphertext.decode(), style='|') + + +def represent_unicode(self, data): + return yaml.representer.SafeRepresenter.represent_str(self, text_type(data)) + + +def represent_binary(self, data): + return yaml.representer.SafeRepresenter.represent_binary(self, binary_type(data)) + + +def represent_undefined(self, data): + # Here bool will ensure _fail_with_undefined_error happens + # if the value is Undefined. + # This happens because Jinja sets __bool__ on StrictUndefined + return bool(data) + + +AnsibleDumper.add_representer( + AnsibleUnicode, + represent_unicode, +) + +AnsibleDumper.add_representer( + AnsibleUnsafeText, + represent_unicode, +) + +AnsibleDumper.add_representer( + AnsibleUnsafeBytes, + represent_binary, +) + +AnsibleDumper.add_representer( + HostVars, + represent_hostvars, +) + +AnsibleDumper.add_representer( + HostVarsVars, + represent_hostvars, +) + +AnsibleDumper.add_representer( + VarsWithSources, + represent_hostvars, +) + +AnsibleDumper.add_representer( + AnsibleSequence, + yaml.representer.SafeRepresenter.represent_list, +) + +AnsibleDumper.add_representer( + AnsibleMapping, + yaml.representer.SafeRepresenter.represent_dict, +) + +AnsibleDumper.add_representer( + AnsibleVaultEncryptedUnicode, + represent_vault_encrypted_unicode, +) + +AnsibleDumper.add_representer( + AnsibleUndefined, + represent_undefined, +) + +AnsibleDumper.add_representer( + NativeJinjaUnsafeText, + represent_unicode, +) + +AnsibleDumper.add_representer( + NativeJinjaText, + represent_unicode, +) diff --git a/lib/ansible/parsing/yaml/loader.py b/lib/ansible/parsing/yaml/loader.py new file mode 100644 index 0000000..15bde79 --- /dev/null +++ b/lib/ansible/parsing/yaml/loader.py @@ -0,0 +1,45 @@ +# (c) 2012-2014, Michael DeHaan 
<michael.dehaan@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from yaml.resolver import Resolver + +from ansible.parsing.yaml.constructor import AnsibleConstructor +from ansible.module_utils.common.yaml import HAS_LIBYAML, Parser + +if HAS_LIBYAML: + class AnsibleLoader(Parser, AnsibleConstructor, Resolver): # type: ignore[misc] # pylint: disable=inconsistent-mro + def __init__(self, stream, file_name=None, vault_secrets=None): + Parser.__init__(self, stream) + AnsibleConstructor.__init__(self, file_name=file_name, vault_secrets=vault_secrets) + Resolver.__init__(self) +else: + from yaml.composer import Composer + from yaml.reader import Reader + from yaml.scanner import Scanner + + class AnsibleLoader(Reader, Scanner, Parser, Composer, AnsibleConstructor, Resolver): # type: ignore[misc,no-redef] # pylint: disable=inconsistent-mro + def __init__(self, stream, file_name=None, vault_secrets=None): + Reader.__init__(self, stream) + Scanner.__init__(self) + Parser.__init__(self) + Composer.__init__(self) + AnsibleConstructor.__init__(self, file_name=file_name, vault_secrets=vault_secrets) + Resolver.__init__(self) diff --git a/lib/ansible/parsing/yaml/objects.py b/lib/ansible/parsing/yaml/objects.py new file mode 100644 
index 0000000..a2e2a66 --- /dev/null +++ b/lib/ansible/parsing/yaml/objects.py @@ -0,0 +1,365 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import string +import sys as _sys + +from collections.abc import Sequence + +import sys +import yaml + +from ansible.module_utils.six import text_type +from ansible.module_utils._text import to_bytes, to_text, to_native + + +class AnsibleBaseYAMLObject(object): + ''' + the base class used to sub-class python built-in objects + so that we can add attributes to them during yaml parsing + + ''' + _data_source = None + _line_number = 0 + _column_number = 0 + + def _get_ansible_position(self): + return (self._data_source, self._line_number, self._column_number) + + def _set_ansible_position(self, obj): + try: + (src, line, col) = obj + except (TypeError, ValueError): + raise AssertionError( + 'ansible_pos can only be set with a tuple/list ' + 'of three values: source, line number, column number' + ) + self._data_source = src + self._line_number = line + self._column_number = col + + ansible_pos = property(_get_ansible_position, _set_ansible_position) + + +class AnsibleMapping(AnsibleBaseYAMLObject, dict): + ''' sub class for dictionaries 
''' + pass + + +class AnsibleUnicode(AnsibleBaseYAMLObject, text_type): + ''' sub class for unicode objects ''' + pass + + +class AnsibleSequence(AnsibleBaseYAMLObject, list): + ''' sub class for lists ''' + pass + + +class AnsibleVaultEncryptedUnicode(Sequence, AnsibleBaseYAMLObject): + '''Unicode like object that is not evaluated (decrypted) until it needs to be''' + __UNSAFE__ = True + __ENCRYPTED__ = True + yaml_tag = u'!vault' + + @classmethod + def from_plaintext(cls, seq, vault, secret): + if not vault: + raise vault.AnsibleVaultError('Error creating AnsibleVaultEncryptedUnicode, invalid vault (%s) provided' % vault) + + ciphertext = vault.encrypt(seq, secret) + avu = cls(ciphertext) + avu.vault = vault + return avu + + def __init__(self, ciphertext): + '''A AnsibleUnicode with a Vault attribute that can decrypt it. + + ciphertext is a byte string (str on PY2, bytestring on PY3). + + The .data attribute is a property that returns the decrypted plaintext + of the ciphertext as a PY2 unicode or PY3 string object. 
+ ''' + super(AnsibleVaultEncryptedUnicode, self).__init__() + + # after construction, calling code has to set the .vault attribute to a vaultlib object + self.vault = None + self._ciphertext = to_bytes(ciphertext) + + @property + def data(self): + if not self.vault: + return to_text(self._ciphertext) + return to_text(self.vault.decrypt(self._ciphertext, obj=self)) + + @data.setter + def data(self, value): + self._ciphertext = to_bytes(value) + + def is_encrypted(self): + return self.vault and self.vault.is_encrypted(self._ciphertext) + + def __eq__(self, other): + if self.vault: + return other == self.data + return False + + def __ne__(self, other): + if self.vault: + return other != self.data + return True + + def __reversed__(self): + # This gets inerhited from ``collections.Sequence`` which returns a generator + # make this act more like the string implementation + return to_text(self[::-1], errors='surrogate_or_strict') + + def __str__(self): + return to_native(self.data, errors='surrogate_or_strict') + + def __unicode__(self): + return to_text(self.data, errors='surrogate_or_strict') + + def encode(self, encoding=None, errors=None): + return to_bytes(self.data, encoding=encoding, errors=errors) + + # Methods below are a copy from ``collections.UserString`` + # Some are copied as is, where others are modified to not + # auto wrap with ``self.__class__`` + def __repr__(self): + return repr(self.data) + + def __int__(self, base=10): + return int(self.data, base=base) + + def __float__(self): + return float(self.data) + + def __complex__(self): + return complex(self.data) + + def __hash__(self): + return hash(self.data) + + # This breaks vault, do not define it, we cannot satisfy this + # def __getnewargs__(self): + # return (self.data[:],) + + def __lt__(self, string): + if isinstance(string, AnsibleVaultEncryptedUnicode): + return self.data < string.data + return self.data < string + + def __le__(self, string): + if isinstance(string, 
AnsibleVaultEncryptedUnicode): + return self.data <= string.data + return self.data <= string + + def __gt__(self, string): + if isinstance(string, AnsibleVaultEncryptedUnicode): + return self.data > string.data + return self.data > string + + def __ge__(self, string): + if isinstance(string, AnsibleVaultEncryptedUnicode): + return self.data >= string.data + return self.data >= string + + def __contains__(self, char): + if isinstance(char, AnsibleVaultEncryptedUnicode): + char = char.data + return char in self.data + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return self.data[index] + + def __getslice__(self, start, end): + start = max(start, 0) + end = max(end, 0) + return self.data[start:end] + + def __add__(self, other): + if isinstance(other, AnsibleVaultEncryptedUnicode): + return self.data + other.data + elif isinstance(other, text_type): + return self.data + other + return self.data + to_text(other) + + def __radd__(self, other): + if isinstance(other, text_type): + return other + self.data + return to_text(other) + self.data + + def __mul__(self, n): + return self.data * n + + __rmul__ = __mul__ + + def __mod__(self, args): + return self.data % args + + def __rmod__(self, template): + return to_text(template) % self + + # the following methods are defined in alphabetical order: + def capitalize(self): + return self.data.capitalize() + + def casefold(self): + return self.data.casefold() + + def center(self, width, *args): + return self.data.center(width, *args) + + def count(self, sub, start=0, end=_sys.maxsize): + if isinstance(sub, AnsibleVaultEncryptedUnicode): + sub = sub.data + return self.data.count(sub, start, end) + + def endswith(self, suffix, start=0, end=_sys.maxsize): + return self.data.endswith(suffix, start, end) + + def expandtabs(self, tabsize=8): + return self.data.expandtabs(tabsize) + + def find(self, sub, start=0, end=_sys.maxsize): + if isinstance(sub, AnsibleVaultEncryptedUnicode): + sub = sub.data 
+ return self.data.find(sub, start, end) + + def format(self, *args, **kwds): + return self.data.format(*args, **kwds) + + def format_map(self, mapping): + return self.data.format_map(mapping) + + def index(self, sub, start=0, end=_sys.maxsize): + return self.data.index(sub, start, end) + + def isalpha(self): + return self.data.isalpha() + + def isalnum(self): + return self.data.isalnum() + + def isascii(self): + return self.data.isascii() + + def isdecimal(self): + return self.data.isdecimal() + + def isdigit(self): + return self.data.isdigit() + + def isidentifier(self): + return self.data.isidentifier() + + def islower(self): + return self.data.islower() + + def isnumeric(self): + return self.data.isnumeric() + + def isprintable(self): + return self.data.isprintable() + + def isspace(self): + return self.data.isspace() + + def istitle(self): + return self.data.istitle() + + def isupper(self): + return self.data.isupper() + + def join(self, seq): + return self.data.join(seq) + + def ljust(self, width, *args): + return self.data.ljust(width, *args) + + def lower(self): + return self.data.lower() + + def lstrip(self, chars=None): + return self.data.lstrip(chars) + + maketrans = str.maketrans + + def partition(self, sep): + return self.data.partition(sep) + + def replace(self, old, new, maxsplit=-1): + if isinstance(old, AnsibleVaultEncryptedUnicode): + old = old.data + if isinstance(new, AnsibleVaultEncryptedUnicode): + new = new.data + return self.data.replace(old, new, maxsplit) + + def rfind(self, sub, start=0, end=_sys.maxsize): + if isinstance(sub, AnsibleVaultEncryptedUnicode): + sub = sub.data + return self.data.rfind(sub, start, end) + + def rindex(self, sub, start=0, end=_sys.maxsize): + return self.data.rindex(sub, start, end) + + def rjust(self, width, *args): + return self.data.rjust(width, *args) + + def rpartition(self, sep): + return self.data.rpartition(sep) + + def rstrip(self, chars=None): + return self.data.rstrip(chars) + + def split(self, 
sep=None, maxsplit=-1): + return self.data.split(sep, maxsplit) + + def rsplit(self, sep=None, maxsplit=-1): + return self.data.rsplit(sep, maxsplit) + + def splitlines(self, keepends=False): + return self.data.splitlines(keepends) + + def startswith(self, prefix, start=0, end=_sys.maxsize): + return self.data.startswith(prefix, start, end) + + def strip(self, chars=None): + return self.data.strip(chars) + + def swapcase(self): + return self.data.swapcase() + + def title(self): + return self.data.title() + + def translate(self, *args): + return self.data.translate(*args) + + def upper(self): + return self.data.upper() + + def zfill(self, width): + return self.data.zfill(width) |