summaryrefslogtreecommitdiffstats
path: root/test/units/parsing
diff options
context:
space:
mode:
Diffstat (limited to 'test/units/parsing')
-rw-r--r--test/units/parsing/__init__.py0
-rw-r--r--test/units/parsing/fixtures/ajson.json19
-rw-r--r--test/units/parsing/fixtures/vault.yml6
-rw-r--r--test/units/parsing/test_ajson.py186
-rw-r--r--test/units/parsing/test_dataloader.py239
-rw-r--r--test/units/parsing/test_mod_args.py137
-rw-r--r--test/units/parsing/test_splitter.py110
-rw-r--r--test/units/parsing/test_unquote.py51
-rw-r--r--test/units/parsing/utils/__init__.py0
-rw-r--r--test/units/parsing/utils/test_addresses.py98
-rw-r--r--test/units/parsing/utils/test_jsonify.py39
-rw-r--r--test/units/parsing/utils/test_yaml.py34
-rw-r--r--test/units/parsing/vault/__init__.py0
-rw-r--r--test/units/parsing/vault/test_vault.py870
-rw-r--r--test/units/parsing/vault/test_vault_editor.py521
-rw-r--r--test/units/parsing/yaml/__init__.py0
-rw-r--r--test/units/parsing/yaml/test_constructor.py84
-rw-r--r--test/units/parsing/yaml/test_dumper.py123
-rw-r--r--test/units/parsing/yaml/test_loader.py432
-rw-r--r--test/units/parsing/yaml/test_objects.py164
20 files changed, 3113 insertions, 0 deletions
diff --git a/test/units/parsing/__init__.py b/test/units/parsing/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/units/parsing/__init__.py
diff --git a/test/units/parsing/fixtures/ajson.json b/test/units/parsing/fixtures/ajson.json
new file mode 100644
index 0000000..dafec0b
--- /dev/null
+++ b/test/units/parsing/fixtures/ajson.json
@@ -0,0 +1,19 @@
+{
+ "password": {
+ "__ansible_vault": "$ANSIBLE_VAULT;1.1;AES256\n34646264306632313333393636316562356435376162633631326264383934326565333633366238\n3863373264326461623132613931346165636465346337310a326434313830316337393263616439\n64653937313463396366633861363266633465663730303633323534363331316164623237363831\n3536333561393238370a313330316263373938326162386433313336613532653538376662306435\n3339\n"
+ },
+ "bar": {
+ "baz": [
+ {
+ "password": {
+ "__ansible_vault": "$ANSIBLE_VAULT;1.1;AES256\n34646264306632313333393636316562356435376162633631326264383934326565333633366238\n3863373264326461623132613931346165636465346337310a326434313830316337393263616439\n64653937313463396366633861363266633465663730303633323534363331316164623237363831\n3536333561393238370a313330316263373938326162386433313336613532653538376662306435\n3338\n"
+ }
+ }
+ ]
+ },
+ "foo": {
+ "password": {
+ "__ansible_vault": "$ANSIBLE_VAULT;1.1;AES256\n34646264306632313333393636316562356435376162633631326264383934326565333633366238\n3863373264326461623132613931346165636465346337310a326434313830316337393263616439\n64653937313463396366633861363266633465663730303633323534363331316164623237363831\n3536333561393238370a313330316263373938326162386433313336613532653538376662306435\n3339\n"
+ }
+ }
+}
diff --git a/test/units/parsing/fixtures/vault.yml b/test/units/parsing/fixtures/vault.yml
new file mode 100644
index 0000000..ca33ab2
--- /dev/null
+++ b/test/units/parsing/fixtures/vault.yml
@@ -0,0 +1,6 @@
+$ANSIBLE_VAULT;1.1;AES256
+33343734386261666161626433386662623039356366656637303939306563376130623138626165
+6436333766346533353463636566313332623130383662340a393835656134633665333861393331
+37666233346464636263636530626332623035633135363732623332313534306438393366323966
+3135306561356164310a343937653834643433343734653137383339323330626437313562306630
+3035
diff --git a/test/units/parsing/test_ajson.py b/test/units/parsing/test_ajson.py
new file mode 100644
index 0000000..1b9a76b
--- /dev/null
+++ b/test/units/parsing/test_ajson.py
@@ -0,0 +1,186 @@
+# Copyright 2018, Matt Martz <matt@sivel.net>
+# Copyright 2019, Andrew Klychkov @Andersson007 <aaklychkov@mail.ru>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+import os
+import json
+
+import pytest
+
+from collections.abc import Mapping
+from datetime import date, datetime, timezone, timedelta
+
+from ansible.parsing.ajson import AnsibleJSONEncoder, AnsibleJSONDecoder
+from ansible.parsing.yaml.objects import AnsibleVaultEncryptedUnicode
+from ansible.utils.unsafe_proxy import AnsibleUnsafeText
+
+
+def test_AnsibleJSONDecoder_vault():
+ with open(os.path.join(os.path.dirname(__file__), 'fixtures/ajson.json')) as f:
+ data = json.load(f, cls=AnsibleJSONDecoder)
+
+ assert isinstance(data['password'], AnsibleVaultEncryptedUnicode)
+ assert isinstance(data['bar']['baz'][0]['password'], AnsibleVaultEncryptedUnicode)
+ assert isinstance(data['foo']['password'], AnsibleVaultEncryptedUnicode)
+
+
+def test_encode_decode_unsafe():
+ data = {
+ 'key_value': AnsibleUnsafeText(u'{#NOTACOMMENT#}'),
+ 'list': [AnsibleUnsafeText(u'{#NOTACOMMENT#}')],
+ 'list_dict': [{'key_value': AnsibleUnsafeText(u'{#NOTACOMMENT#}')}]}
+ json_expected = (
+ '{"key_value": {"__ansible_unsafe": "{#NOTACOMMENT#}"}, '
+ '"list": [{"__ansible_unsafe": "{#NOTACOMMENT#}"}], '
+ '"list_dict": [{"key_value": {"__ansible_unsafe": "{#NOTACOMMENT#}"}}]}'
+ )
+ assert json.dumps(data, cls=AnsibleJSONEncoder, preprocess_unsafe=True, sort_keys=True) == json_expected
+ assert json.loads(json_expected, cls=AnsibleJSONDecoder) == data
+
+
+def vault_data():
+ """
+ Prepare AnsibleVaultEncryptedUnicode test data for AnsibleJSONEncoder.default().
+
+ Return a list of tuples (input, expected).
+ """
+
+ with open(os.path.join(os.path.dirname(__file__), 'fixtures/ajson.json')) as f:
+ data = json.load(f, cls=AnsibleJSONDecoder)
+
+ data_0 = data['password']
+ data_1 = data['bar']['baz'][0]['password']
+
+ expected_0 = (u'$ANSIBLE_VAULT;1.1;AES256\n34646264306632313333393636316'
+ '562356435376162633631326264383934326565333633366238\n3863'
+ '373264326461623132613931346165636465346337310a32643431383'
+ '0316337393263616439\n646539373134633963666338613632666334'
+ '65663730303633323534363331316164623237363831\n35363335613'
+ '93238370a313330316263373938326162386433313336613532653538'
+ '376662306435\n3339\n')
+
+ expected_1 = (u'$ANSIBLE_VAULT;1.1;AES256\n34646264306632313333393636316'
+ '562356435376162633631326264383934326565333633366238\n3863'
+ '373264326461623132613931346165636465346337310a32643431383'
+ '0316337393263616439\n646539373134633963666338613632666334'
+ '65663730303633323534363331316164623237363831\n35363335613'
+ '93238370a313330316263373938326162386433313336613532653538'
+ '376662306435\n3338\n')
+
+ return [
+ (data_0, expected_0),
+ (data_1, expected_1),
+ ]
+
+
+class TestAnsibleJSONEncoder:
+
+ """
+ Namespace for testing AnsibleJSONEncoder.
+ """
+
+ @pytest.fixture(scope='class')
+ def mapping(self, request):
+ """
+ Returns object of Mapping mock class.
+
+ The object is used for testing handling of Mapping objects
+ in AnsibleJSONEncoder.default().
+ Using a plain dictionary instead is not suitable because
+ it is handled by default encoder of the superclass (json.JSONEncoder).
+ """
+
+ class M(Mapping):
+
+ """Mock mapping class."""
+
+ def __init__(self, *args, **kwargs):
+ self.__dict__.update(*args, **kwargs)
+
+ def __getitem__(self, key):
+ return self.__dict__[key]
+
+ def __iter__(self):
+ return iter(self.__dict__)
+
+ def __len__(self):
+ return len(self.__dict__)
+
+ return M(request.param)
+
+ @pytest.fixture
+ def ansible_json_encoder(self):
+ """Return AnsibleJSONEncoder object."""
+ return AnsibleJSONEncoder()
+
+ ###############
+ # Test methods:
+
+ @pytest.mark.parametrize(
+ 'test_input,expected',
+ [
+ (datetime(2019, 5, 14, 13, 39, 38, 569047), '2019-05-14T13:39:38.569047'),
+ (datetime(2019, 5, 14, 13, 47, 16, 923866), '2019-05-14T13:47:16.923866'),
+ (date(2019, 5, 14), '2019-05-14'),
+ (date(2020, 5, 14), '2020-05-14'),
+ (datetime(2019, 6, 15, 14, 45, tzinfo=timezone.utc), '2019-06-15T14:45:00+00:00'),
+ (datetime(2019, 6, 15, 14, 45, tzinfo=timezone(timedelta(hours=1, minutes=40))), '2019-06-15T14:45:00+01:40'),
+ ]
+ )
+ def test_date_datetime(self, ansible_json_encoder, test_input, expected):
+ """
+ Test for passing datetime.date or datetime.datetime objects to AnsibleJSONEncoder.default().
+ """
+ assert ansible_json_encoder.default(test_input) == expected
+
+ @pytest.mark.parametrize(
+ 'mapping,expected',
+ [
+ ({1: 1}, {1: 1}),
+ ({2: 2}, {2: 2}),
+ ({1: 2}, {1: 2}),
+ ({2: 1}, {2: 1}),
+ ], indirect=['mapping'],
+ )
+ def test_mapping(self, ansible_json_encoder, mapping, expected):
+ """
+ Test for passing Mapping object to AnsibleJSONEncoder.default().
+ """
+ assert ansible_json_encoder.default(mapping) == expected
+
+ @pytest.mark.parametrize('test_input,expected', vault_data())
+ def test_ansible_json_decoder_vault(self, ansible_json_encoder, test_input, expected):
+ """
+ Test for passing AnsibleVaultEncryptedUnicode to AnsibleJSONEncoder.default().
+ """
+ assert ansible_json_encoder.default(test_input) == {'__ansible_vault': expected}
+ assert json.dumps(test_input, cls=AnsibleJSONEncoder, preprocess_unsafe=True) == '{"__ansible_vault": "%s"}' % expected.replace('\n', '\\n')
+
+ @pytest.mark.parametrize(
+ 'test_input,expected',
+ [
+ ({1: 'first'}, {1: 'first'}),
+ ({2: 'second'}, {2: 'second'}),
+ ]
+ )
+ def test_default_encoder(self, ansible_json_encoder, test_input, expected):
+ """
+ Test for the default encoder of AnsibleJSONEncoder.default().
+
+ If objects of different classes that are not tested above were passed,
+ AnsibleJSONEncoder.default() invokes 'default()' method of json.JSONEncoder superclass.
+ """
+ assert ansible_json_encoder.default(test_input) == expected
+
+ @pytest.mark.parametrize('test_input', [1, 1.1, 'string', [1, 2], set('set'), True, None])
+ def test_default_encoder_unserializable(self, ansible_json_encoder, test_input):
+ """
+ Test for the default encoder of AnsibleJSONEncoder.default(), not serializable objects.
+
+ It must fail with TypeError 'object is not serializable'.
+ """
+ with pytest.raises(TypeError):
+ ansible_json_encoder.default(test_input)
diff --git a/test/units/parsing/test_dataloader.py b/test/units/parsing/test_dataloader.py
new file mode 100644
index 0000000..9ec49a8
--- /dev/null
+++ b/test/units/parsing/test_dataloader.py
@@ -0,0 +1,239 @@
+# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import os
+
+from units.compat import unittest
+from unittest.mock import patch, mock_open
+from ansible.errors import AnsibleParserError, yaml_strings, AnsibleFileNotFound
+from ansible.parsing.vault import AnsibleVaultError
+from ansible.module_utils._text import to_text
+from ansible.module_utils.six import PY3
+
+from units.mock.vault_helper import TextVaultSecret
+from ansible.parsing.dataloader import DataLoader
+
+from units.mock.path import mock_unfrackpath_noop
+
+
+class TestDataLoader(unittest.TestCase):
+
+ def setUp(self):
+ self._loader = DataLoader()
+
+ @patch('os.path.exists')
+ def test__is_role(self, p_exists):
+ p_exists.side_effect = lambda p: p == b'test_path/tasks/main.yml'
+ self.assertTrue(self._loader._is_role('test_path/tasks'))
+ self.assertTrue(self._loader._is_role('test_path/'))
+
+ @patch.object(DataLoader, '_get_file_contents')
+ def test_parse_json_from_file(self, mock_def):
+ mock_def.return_value = (b"""{"a": 1, "b": 2, "c": 3}""", True)
+ output = self._loader.load_from_file('dummy_json.txt')
+ self.assertEqual(output, dict(a=1, b=2, c=3))
+
+ @patch.object(DataLoader, '_get_file_contents')
+ def test_parse_yaml_from_file(self, mock_def):
+ mock_def.return_value = (b"""
+ a: 1
+ b: 2
+ c: 3
+ """, True)
+ output = self._loader.load_from_file('dummy_yaml.txt')
+ self.assertEqual(output, dict(a=1, b=2, c=3))
+
+ @patch.object(DataLoader, '_get_file_contents')
+ def test_parse_fail_from_file(self, mock_def):
+ mock_def.return_value = (b"""
+ TEXT:
+ ***
+ NOT VALID
+ """, True)
+ self.assertRaises(AnsibleParserError, self._loader.load_from_file, 'dummy_yaml_bad.txt')
+
+ @patch('ansible.errors.AnsibleError._get_error_lines_from_file')
+ @patch.object(DataLoader, '_get_file_contents')
+ def test_tab_error(self, mock_def, mock_get_error_lines):
+ mock_def.return_value = (u"""---\nhosts: localhost\nvars:\n foo: bar\n\tblip: baz""", True)
+ mock_get_error_lines.return_value = ('''\tblip: baz''', '''..foo: bar''')
+ with self.assertRaises(AnsibleParserError) as cm:
+ self._loader.load_from_file('dummy_yaml_text.txt')
+ self.assertIn(yaml_strings.YAML_COMMON_LEADING_TAB_ERROR, str(cm.exception))
+ self.assertIn('foo: bar', str(cm.exception))
+
+ @patch('ansible.parsing.dataloader.unfrackpath', mock_unfrackpath_noop)
+ @patch.object(DataLoader, '_is_role')
+ def test_path_dwim_relative(self, mock_is_role):
+ """
+ simulate a nested dynamic include:
+
+ playbook.yml:
+ - hosts: localhost
+ roles:
+ - { role: 'testrole' }
+
+ testrole/tasks/main.yml:
+ - include: "include1.yml"
+ static: no
+
+ testrole/tasks/include1.yml:
+ - include: include2.yml
+ static: no
+
+ testrole/tasks/include2.yml:
+ - debug: msg="blah"
+ """
+ mock_is_role.return_value = False
+ with patch('os.path.exists') as mock_os_path_exists:
+ mock_os_path_exists.return_value = False
+ self._loader.path_dwim_relative('/tmp/roles/testrole/tasks', 'tasks', 'included2.yml')
+
+ # Fetch first args for every call
+ # mock_os_path_exists.assert_any_call isn't used because os.path.normpath must be used in order to compare paths
+ called_args = [os.path.normpath(to_text(call[0][0])) for call in mock_os_path_exists.call_args_list]
+
+ # 'path_dwim_relative' docstrings say 'with or without explicitly named dirname subdirs':
+ self.assertIn('/tmp/roles/testrole/tasks/included2.yml', called_args)
+ self.assertIn('/tmp/roles/testrole/tasks/tasks/included2.yml', called_args)
+
+ # relative directories below are taken in account too:
+ self.assertIn('tasks/included2.yml', called_args)
+ self.assertIn('included2.yml', called_args)
+
+ def test_path_dwim_root(self):
+ self.assertEqual(self._loader.path_dwim('/'), '/')
+
+ def test_path_dwim_home(self):
+ self.assertEqual(self._loader.path_dwim('~'), os.path.expanduser('~'))
+
+ def test_path_dwim_tilde_slash(self):
+ self.assertEqual(self._loader.path_dwim('~/'), os.path.expanduser('~'))
+
+ def test_get_real_file(self):
+ self.assertEqual(self._loader.get_real_file(__file__), __file__)
+
+ def test_is_file(self):
+ self.assertTrue(self._loader.is_file(__file__))
+
+ def test_is_directory_positive(self):
+ self.assertTrue(self._loader.is_directory(os.path.dirname(__file__)))
+
+ def test_get_file_contents_none_path(self):
+ self.assertRaisesRegex(AnsibleParserError, 'Invalid filename',
+ self._loader._get_file_contents, None)
+
+ def test_get_file_contents_non_existent_path(self):
+ self.assertRaises(AnsibleFileNotFound, self._loader._get_file_contents, '/non_existent_file')
+
+
+class TestPathDwimRelativeDataLoader(unittest.TestCase):
+
+ def setUp(self):
+ self._loader = DataLoader()
+
+ def test_all_slash(self):
+ self.assertEqual(self._loader.path_dwim_relative('/', '/', '/'), '/')
+
+ def test_path_endswith_role(self):
+ self.assertEqual(self._loader.path_dwim_relative(path='foo/bar/tasks/', dirname='/', source='/'), '/')
+
+ def test_path_endswith_role_main_yml(self):
+ self.assertIn('main.yml', self._loader.path_dwim_relative(path='foo/bar/tasks/', dirname='/', source='main.yml'))
+
+ def test_path_endswith_role_source_tilde(self):
+ self.assertEqual(self._loader.path_dwim_relative(path='foo/bar/tasks/', dirname='/', source='~/'), os.path.expanduser('~'))
+
+
+class TestPathDwimRelativeStackDataLoader(unittest.TestCase):
+
+ def setUp(self):
+ self._loader = DataLoader()
+
+ def test_none(self):
+ self.assertRaisesRegex(AnsibleFileNotFound, 'on the Ansible Controller', self._loader.path_dwim_relative_stack, None, None, None)
+
+ def test_empty_strings(self):
+ self.assertEqual(self._loader.path_dwim_relative_stack('', '', ''), './')
+
+ def test_empty_lists(self):
+ self.assertEqual(self._loader.path_dwim_relative_stack([], '', '~/'), os.path.expanduser('~'))
+
+ def test_all_slash(self):
+ self.assertEqual(self._loader.path_dwim_relative_stack('/', '/', '/'), '/')
+
+ def test_path_endswith_role(self):
+ self.assertEqual(self._loader.path_dwim_relative_stack(paths=['foo/bar/tasks/'], dirname='/', source='/'), '/')
+
+ def test_path_endswith_role_source_tilde(self):
+ self.assertEqual(self._loader.path_dwim_relative_stack(paths=['foo/bar/tasks/'], dirname='/', source='~/'), os.path.expanduser('~'))
+
+ def test_path_endswith_role_source_main_yml(self):
+ self.assertRaises(AnsibleFileNotFound, self._loader.path_dwim_relative_stack, ['foo/bar/tasks/'], '/', 'main.yml')
+
+ def test_path_endswith_role_source_main_yml_source_in_dirname(self):
+ self.assertRaises(AnsibleFileNotFound, self._loader.path_dwim_relative_stack, 'foo/bar/tasks/', 'tasks', 'tasks/main.yml')
+
+
+class TestDataLoaderWithVault(unittest.TestCase):
+
+ def setUp(self):
+ self._loader = DataLoader()
+ vault_secrets = [('default', TextVaultSecret('ansible'))]
+ self._loader.set_vault_secrets(vault_secrets)
+ self.test_vault_data_path = os.path.join(os.path.dirname(__file__), 'fixtures', 'vault.yml')
+
+ def tearDown(self):
+ pass
+
+ def test_get_real_file_vault(self):
+ real_file_path = self._loader.get_real_file(self.test_vault_data_path)
+ self.assertTrue(os.path.exists(real_file_path))
+
+ def test_get_real_file_vault_no_vault(self):
+ self._loader.set_vault_secrets(None)
+ self.assertRaises(AnsibleParserError, self._loader.get_real_file, self.test_vault_data_path)
+
+ def test_get_real_file_vault_wrong_password(self):
+ wrong_vault = [('default', TextVaultSecret('wrong_password'))]
+ self._loader.set_vault_secrets(wrong_vault)
+ self.assertRaises(AnsibleVaultError, self._loader.get_real_file, self.test_vault_data_path)
+
+ def test_get_real_file_not_a_path(self):
+ self.assertRaisesRegex(AnsibleParserError, 'Invalid filename', self._loader.get_real_file, None)
+
+ @patch.multiple(DataLoader, path_exists=lambda s, x: True, is_file=lambda s, x: True)
+ def test_parse_from_vault_1_1_file(self):
+ vaulted_data = """$ANSIBLE_VAULT;1.1;AES256
+33343734386261666161626433386662623039356366656637303939306563376130623138626165
+6436333766346533353463636566313332623130383662340a393835656134633665333861393331
+37666233346464636263636530626332623035633135363732623332313534306438393366323966
+3135306561356164310a343937653834643433343734653137383339323330626437313562306630
+3035
+"""
+ if PY3:
+ builtins_name = 'builtins'
+ else:
+ builtins_name = '__builtin__'
+
+ with patch(builtins_name + '.open', mock_open(read_data=vaulted_data.encode('utf-8'))):
+ output = self._loader.load_from_file('dummy_vault.txt')
+ self.assertEqual(output, dict(foo='bar'))
diff --git a/test/units/parsing/test_mod_args.py b/test/units/parsing/test_mod_args.py
new file mode 100644
index 0000000..5d3f5d2
--- /dev/null
+++ b/test/units/parsing/test_mod_args.py
@@ -0,0 +1,137 @@
+# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
+# Copyright 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+import pytest
+import re
+
+from ansible.errors import AnsibleParserError
+from ansible.parsing.mod_args import ModuleArgsParser
+from ansible.utils.sentinel import Sentinel
+
+
+class TestModArgsDwim:
+
+ # TODO: add tests that construct ModuleArgsParser with a task reference
+ # TODO: verify the AnsibleError raised on failure knows the task
+ # and the task knows the line numbers
+
+ INVALID_MULTIPLE_ACTIONS = (
+ ({'action': 'shell echo hi', 'local_action': 'shell echo hi'}, "action and local_action are mutually exclusive"),
+ ({'action': 'shell echo hi', 'shell': 'echo hi'}, "conflicting action statements: shell, shell"),
+ ({'local_action': 'shell echo hi', 'shell': 'echo hi'}, "conflicting action statements: shell, shell"),
+ )
+
+ def _debug(self, mod, args, to):
+ print("RETURNED module = {0}".format(mod))
+ print(" args = {0}".format(args))
+ print(" to = {0}".format(to))
+
+ def test_basic_shell(self):
+ m = ModuleArgsParser(dict(shell='echo hi'))
+ mod, args, to = m.parse()
+ self._debug(mod, args, to)
+
+ assert mod == 'shell'
+ assert args == dict(
+ _raw_params='echo hi',
+ )
+ assert to is Sentinel
+
+ def test_basic_command(self):
+ m = ModuleArgsParser(dict(command='echo hi'))
+ mod, args, to = m.parse()
+ self._debug(mod, args, to)
+
+ assert mod == 'command'
+ assert args == dict(
+ _raw_params='echo hi',
+ )
+ assert to is Sentinel
+
+ def test_shell_with_modifiers(self):
+ m = ModuleArgsParser(dict(shell='/bin/foo creates=/tmp/baz removes=/tmp/bleep'))
+ mod, args, to = m.parse()
+ self._debug(mod, args, to)
+
+ assert mod == 'shell'
+ assert args == dict(
+ creates='/tmp/baz',
+ removes='/tmp/bleep',
+ _raw_params='/bin/foo',
+ )
+ assert to is Sentinel
+
+ def test_normal_usage(self):
+ m = ModuleArgsParser(dict(copy='src=a dest=b'))
+ mod, args, to = m.parse()
+ self._debug(mod, args, to)
+
+ assert mod, 'copy'
+ assert args, dict(src='a', dest='b')
+ assert to is Sentinel
+
+ def test_complex_args(self):
+ m = ModuleArgsParser(dict(copy=dict(src='a', dest='b')))
+ mod, args, to = m.parse()
+ self._debug(mod, args, to)
+
+ assert mod, 'copy'
+ assert args, dict(src='a', dest='b')
+ assert to is Sentinel
+
+ def test_action_with_complex(self):
+ m = ModuleArgsParser(dict(action=dict(module='copy', src='a', dest='b')))
+ mod, args, to = m.parse()
+ self._debug(mod, args, to)
+
+ assert mod == 'copy'
+ assert args == dict(src='a', dest='b')
+ assert to is Sentinel
+
+ def test_action_with_complex_and_complex_args(self):
+ m = ModuleArgsParser(dict(action=dict(module='copy', args=dict(src='a', dest='b'))))
+ mod, args, to = m.parse()
+ self._debug(mod, args, to)
+
+ assert mod == 'copy'
+ assert args == dict(src='a', dest='b')
+ assert to is Sentinel
+
+ def test_local_action_string(self):
+ m = ModuleArgsParser(dict(local_action='copy src=a dest=b'))
+ mod, args, delegate_to = m.parse()
+ self._debug(mod, args, delegate_to)
+
+ assert mod == 'copy'
+ assert args == dict(src='a', dest='b')
+ assert delegate_to == 'localhost'
+
+ @pytest.mark.parametrize("args_dict, msg", INVALID_MULTIPLE_ACTIONS)
+ def test_multiple_actions(self, args_dict, msg):
+ m = ModuleArgsParser(args_dict)
+ with pytest.raises(AnsibleParserError) as err:
+ m.parse()
+
+ assert err.value.args[0] == msg
+
+ def test_multiple_actions_ping_shell(self):
+ args_dict = {'ping': 'data=hi', 'shell': 'echo hi'}
+ m = ModuleArgsParser(args_dict)
+ with pytest.raises(AnsibleParserError) as err:
+ m.parse()
+
+ assert err.value.args[0].startswith("conflicting action statements: ")
+ actions = set(re.search(r'(\w+), (\w+)', err.value.args[0]).groups())
+ assert actions == set(['ping', 'shell'])
+
+ def test_bogus_action(self):
+ args_dict = {'bogusaction': {}}
+ m = ModuleArgsParser(args_dict)
+ with pytest.raises(AnsibleParserError) as err:
+ m.parse()
+
+ assert err.value.args[0].startswith("couldn't resolve module/action 'bogusaction'")
diff --git a/test/units/parsing/test_splitter.py b/test/units/parsing/test_splitter.py
new file mode 100644
index 0000000..a37de0f
--- /dev/null
+++ b/test/units/parsing/test_splitter.py
@@ -0,0 +1,110 @@
+# coding: utf-8
+# (c) 2015, Toshio Kuratomi <tkuratomi@ansible.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from ansible.parsing.splitter import split_args, parse_kv
+
+import pytest
+
+SPLIT_DATA = (
+ (u'a',
+ [u'a'],
+ {u'_raw_params': u'a'}),
+ (u'a=b',
+ [u'a=b'],
+ {u'a': u'b'}),
+ (u'a="foo bar"',
+ [u'a="foo bar"'],
+ {u'a': u'foo bar'}),
+ (u'"foo bar baz"',
+ [u'"foo bar baz"'],
+ {u'_raw_params': '"foo bar baz"'}),
+ (u'foo bar baz',
+ [u'foo', u'bar', u'baz'],
+ {u'_raw_params': u'foo bar baz'}),
+ (u'a=b c="foo bar"',
+ [u'a=b', u'c="foo bar"'],
+ {u'a': u'b', u'c': u'foo bar'}),
+ (u'a="echo \\"hello world\\"" b=bar',
+ [u'a="echo \\"hello world\\""', u'b=bar'],
+ {u'a': u'echo "hello world"', u'b': u'bar'}),
+ (u'a="multi\nline"',
+ [u'a="multi\nline"'],
+ {u'a': u'multi\nline'}),
+ (u'a="blank\n\nline"',
+ [u'a="blank\n\nline"'],
+ {u'a': u'blank\n\nline'}),
+ (u'a="blank\n\n\nlines"',
+ [u'a="blank\n\n\nlines"'],
+ {u'a': u'blank\n\n\nlines'}),
+ (u'a="a long\nmessage\\\nabout a thing\n"',
+ [u'a="a long\nmessage\\\nabout a thing\n"'],
+ {u'a': u'a long\nmessage\\\nabout a thing\n'}),
+ (u'a="multiline\nmessage1\\\n" b="multiline\nmessage2\\\n"',
+ [u'a="multiline\nmessage1\\\n"', u'b="multiline\nmessage2\\\n"'],
+ {u'a': 'multiline\nmessage1\\\n', u'b': u'multiline\nmessage2\\\n'}),
+ (u'a={{jinja}}',
+ [u'a={{jinja}}'],
+ {u'a': u'{{jinja}}'}),
+ (u'a={{ jinja }}',
+ [u'a={{ jinja }}'],
+ {u'a': u'{{ jinja }}'}),
+ (u'a="{{jinja}}"',
+ [u'a="{{jinja}}"'],
+ {u'a': u'{{jinja}}'}),
+ (u'a={{ jinja }}{{jinja2}}',
+ [u'a={{ jinja }}{{jinja2}}'],
+ {u'a': u'{{ jinja }}{{jinja2}}'}),
+ (u'a="{{ jinja }}{{jinja2}}"',
+ [u'a="{{ jinja }}{{jinja2}}"'],
+ {u'a': u'{{ jinja }}{{jinja2}}'}),
+ (u'a={{jinja}} b={{jinja2}}',
+ [u'a={{jinja}}', u'b={{jinja2}}'],
+ {u'a': u'{{jinja}}', u'b': u'{{jinja2}}'}),
+ (u'a="{{jinja}}\n" b="{{jinja2}}\n"',
+ [u'a="{{jinja}}\n"', u'b="{{jinja2}}\n"'],
+ {u'a': u'{{jinja}}\n', u'b': u'{{jinja2}}\n'}),
+ (u'a="café eñyei"',
+ [u'a="café eñyei"'],
+ {u'a': u'café eñyei'}),
+ (u'a=café b=eñyei',
+ [u'a=café', u'b=eñyei'],
+ {u'a': u'café', u'b': u'eñyei'}),
+ (u'a={{ foo | some_filter(\' \', " ") }} b=bar',
+ [u'a={{ foo | some_filter(\' \', " ") }}', u'b=bar'],
+ {u'a': u'{{ foo | some_filter(\' \', " ") }}', u'b': u'bar'}),
+ (u'One\n Two\n Three\n',
+ [u'One\n ', u'Two\n ', u'Three\n'],
+ {u'_raw_params': u'One\n Two\n Three\n'}),
+)
+
+SPLIT_ARGS = ((test[0], test[1]) for test in SPLIT_DATA)
+PARSE_KV = ((test[0], test[2]) for test in SPLIT_DATA)
+
+
+@pytest.mark.parametrize("args, expected", SPLIT_ARGS)
+def test_split_args(args, expected):
+ assert split_args(args) == expected
+
+
+@pytest.mark.parametrize("args, expected", PARSE_KV)
+def test_parse_kv(args, expected):
+ assert parse_kv(args) == expected
diff --git a/test/units/parsing/test_unquote.py b/test/units/parsing/test_unquote.py
new file mode 100644
index 0000000..4b4260e
--- /dev/null
+++ b/test/units/parsing/test_unquote.py
@@ -0,0 +1,51 @@
+# coding: utf-8
+# (c) 2015, Toshio Kuratomi <tkuratomi@ansible.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from ansible.parsing.quoting import unquote
+
+import pytest
+
+UNQUOTE_DATA = (
+ (u'1', u'1'),
+ (u'\'1\'', u'1'),
+ (u'"1"', u'1'),
+ (u'"1 \'2\'"', u'1 \'2\''),
+ (u'\'1 "2"\'', u'1 "2"'),
+ (u'\'1 \'2\'\'', u'1 \'2\''),
+ (u'"1\\"', u'"1\\"'),
+ (u'\'1\\\'', u'\'1\\\''),
+ (u'"1 \\"2\\" 3"', u'1 \\"2\\" 3'),
+ (u'\'1 \\\'2\\\' 3\'', u'1 \\\'2\\\' 3'),
+ (u'"', u'"'),
+ (u'\'', u'\''),
+ # Not entirely sure these are good but they match the current
+ # behaviour
+ (u'"1""2"', u'1""2'),
+ (u'\'1\'\'2\'', u'1\'\'2'),
+ (u'"1" 2 "3"', u'1" 2 "3'),
+ (u'"1"\'2\'"3"', u'1"\'2\'"3'),
+)
+
+
+@pytest.mark.parametrize("quoted, expected", UNQUOTE_DATA)
+def test_unquote(quoted, expected):
+ assert unquote(quoted) == expected
diff --git a/test/units/parsing/utils/__init__.py b/test/units/parsing/utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/units/parsing/utils/__init__.py
diff --git a/test/units/parsing/utils/test_addresses.py b/test/units/parsing/utils/test_addresses.py
new file mode 100644
index 0000000..4f7304f
--- /dev/null
+++ b/test/units/parsing/utils/test_addresses.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import unittest
+
+from ansible.parsing.utils.addresses import parse_address
+
+
+class TestParseAddress(unittest.TestCase):
+
+ tests = {
+ # IPv4 addresses
+ '192.0.2.3': ['192.0.2.3', None],
+ '192.0.2.3:23': ['192.0.2.3', 23],
+
+ # IPv6 addresses
+ '::': ['::', None],
+ '::1': ['::1', None],
+ '[::1]:442': ['::1', 442],
+ 'abcd:ef98:7654:3210:abcd:ef98:7654:3210': ['abcd:ef98:7654:3210:abcd:ef98:7654:3210', None],
+ '[abcd:ef98:7654:3210:abcd:ef98:7654:3210]:42': ['abcd:ef98:7654:3210:abcd:ef98:7654:3210', 42],
+ '1234:5678:9abc:def0:1234:5678:9abc:def0': ['1234:5678:9abc:def0:1234:5678:9abc:def0', None],
+ '1234::9abc:def0:1234:5678:9abc:def0': ['1234::9abc:def0:1234:5678:9abc:def0', None],
+ '1234:5678::def0:1234:5678:9abc:def0': ['1234:5678::def0:1234:5678:9abc:def0', None],
+ '1234:5678:9abc::1234:5678:9abc:def0': ['1234:5678:9abc::1234:5678:9abc:def0', None],
+ '1234:5678:9abc:def0::5678:9abc:def0': ['1234:5678:9abc:def0::5678:9abc:def0', None],
+ '1234:5678:9abc:def0:1234::9abc:def0': ['1234:5678:9abc:def0:1234::9abc:def0', None],
+ '1234:5678:9abc:def0:1234:5678::def0': ['1234:5678:9abc:def0:1234:5678::def0', None],
+ '1234:5678:9abc:def0:1234:5678::': ['1234:5678:9abc:def0:1234:5678::', None],
+ '::9abc:def0:1234:5678:9abc:def0': ['::9abc:def0:1234:5678:9abc:def0', None],
+ '0:0:0:0:0:ffff:1.2.3.4': ['0:0:0:0:0:ffff:1.2.3.4', None],
+ '0:0:0:0:0:0:1.2.3.4': ['0:0:0:0:0:0:1.2.3.4', None],
+ '::ffff:1.2.3.4': ['::ffff:1.2.3.4', None],
+ '::1.2.3.4': ['::1.2.3.4', None],
+ '1234::': ['1234::', None],
+
+ # Hostnames
+ 'some-host': ['some-host', None],
+ 'some-host:80': ['some-host', 80],
+ 'some.host.com:492': ['some.host.com', 492],
+ '[some.host.com]:493': ['some.host.com', 493],
+ 'a-b.3foo_bar.com:23': ['a-b.3foo_bar.com', 23],
+ u'fóöbär': [u'fóöbär', None],
+ u'fóöbär:32': [u'fóöbär', 32],
+ u'fóöbär.éxàmplê.com:632': [u'fóöbär.éxàmplê.com', 632],
+
+ # Various errors
+ '': [None, None],
+ 'some..host': [None, None],
+ 'some.': [None, None],
+ '[example.com]': [None, None],
+ 'some-': [None, None],
+ 'some-.foo.com': [None, None],
+ 'some.-foo.com': [None, None],
+ }
+
+ range_tests = {
+ '192.0.2.[3:10]': ['192.0.2.[3:10]', None],
+ '192.0.2.[3:10]:23': ['192.0.2.[3:10]', 23],
+ 'abcd:ef98::7654:[1:9]': ['abcd:ef98::7654:[1:9]', None],
+ '[abcd:ef98::7654:[6:32]]:2222': ['abcd:ef98::7654:[6:32]', 2222],
+ '[abcd:ef98::7654:[9ab3:fcb7]]:2222': ['abcd:ef98::7654:[9ab3:fcb7]', 2222],
+ u'fóöb[a:c]r.éxàmplê.com:632': [u'fóöb[a:c]r.éxàmplê.com', 632],
+ '[a:b]foo.com': ['[a:b]foo.com', None],
+ 'foo[a:b].com': ['foo[a:b].com', None],
+ 'foo[a:b]:42': ['foo[a:b]', 42],
+ 'foo[a-b]-.com': [None, None],
+ 'foo[a-b]:32': [None, None],
+ 'foo[x-y]': [None, None],
+ }
+
+ def test_without_ranges(self):
+ for t in self.tests:
+ test = self.tests[t]
+
+ try:
+ (host, port) = parse_address(t)
+ except Exception:
+ host = None
+ port = None
+
+ assert host == test[0]
+ assert port == test[1]
+
+ def test_with_ranges(self):
+ for t in self.range_tests:
+ test = self.range_tests[t]
+
+ try:
+ (host, port) = parse_address(t, allow_ranges=True)
+ except Exception:
+ host = None
+ port = None
+
+ assert host == test[0]
+ assert port == test[1]
diff --git a/test/units/parsing/utils/test_jsonify.py b/test/units/parsing/utils/test_jsonify.py
new file mode 100644
index 0000000..37be782
--- /dev/null
+++ b/test/units/parsing/utils/test_jsonify.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+# (c) 2016, James Cammarata <jimi@sngx.net>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from units.compat import unittest
+from ansible.parsing.utils.jsonify import jsonify
+
+
+class TestJsonify(unittest.TestCase):
+ def test_jsonify_simple(self):
+ self.assertEqual(jsonify(dict(a=1, b=2, c=3)), '{"a": 1, "b": 2, "c": 3}')
+
+ def test_jsonify_simple_format(self):
+ res = jsonify(dict(a=1, b=2, c=3), format=True)
+ cleaned = "".join([x.strip() for x in res.splitlines()])
+ self.assertEqual(cleaned, '{"a": 1,"b": 2,"c": 3}')
+
+ def test_jsonify_unicode(self):
+ self.assertEqual(jsonify(dict(toshio=u'くらとみ')), u'{"toshio": "くらとみ"}')
+
+ def test_jsonify_empty(self):
+ self.assertEqual(jsonify(None), '{}')
diff --git a/test/units/parsing/utils/test_yaml.py b/test/units/parsing/utils/test_yaml.py
new file mode 100644
index 0000000..27b2905
--- /dev/null
+++ b/test/units/parsing/utils/test_yaml.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# (c) 2017, Ansible Project
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import pytest
+
+from ansible.errors import AnsibleParserError
+from ansible.parsing.utils.yaml import from_yaml
+
+
+def test_from_yaml_simple():
+ assert from_yaml(u'---\n- test: 1\n test2: "2"\n- caf\xe9: "caf\xe9"') == [{u'test': 1, u'test2': u"2"}, {u"caf\xe9": u"caf\xe9"}]
+
+
+def test_bad_yaml():
+ with pytest.raises(AnsibleParserError):
+ from_yaml(u'foo: bar: baz')
diff --git a/test/units/parsing/vault/__init__.py b/test/units/parsing/vault/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/units/parsing/vault/__init__.py
diff --git a/test/units/parsing/vault/test_vault.py b/test/units/parsing/vault/test_vault.py
new file mode 100644
index 0000000..7afd356
--- /dev/null
+++ b/test/units/parsing/vault/test_vault.py
@@ -0,0 +1,870 @@
+# -*- coding: utf-8 -*-
+# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
+# (c) 2016, Toshio Kuratomi <tkuratomi@ansible.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import binascii
+import io
+import os
+import tempfile
+
+from binascii import hexlify
+import pytest
+
+from units.compat import unittest
+from unittest.mock import patch, MagicMock
+
+from ansible import errors
+from ansible.module_utils import six
+from ansible.module_utils._text import to_bytes, to_text
+from ansible.parsing import vault
+
+from units.mock.loader import DictDataLoader
+from units.mock.vault_helper import TextVaultSecret
+
+
+class TestUnhexlify(unittest.TestCase):
+ def test(self):
+ b_plain_data = b'some text to hexlify'
+ b_data = hexlify(b_plain_data)
+ res = vault._unhexlify(b_data)
+ self.assertEqual(res, b_plain_data)
+
+ def test_odd_length(self):
+ b_data = b'123456789abcdefghijklmnopqrstuvwxyz'
+
+ self.assertRaisesRegex(vault.AnsibleVaultFormatError,
+ '.*Vault format unhexlify error.*',
+ vault._unhexlify,
+ b_data)
+
+ def test_nonhex(self):
+ b_data = b'6z36316566653264333665333637623064303639353237620a636366633565663263336335656532'
+
+ self.assertRaisesRegex(vault.AnsibleVaultFormatError,
+ '.*Vault format unhexlify error.*Non-hexadecimal digit found',
+ vault._unhexlify,
+ b_data)
+
+
+class TestParseVaulttext(unittest.TestCase):
+ def test(self):
+ vaulttext_envelope = u'''$ANSIBLE_VAULT;1.1;AES256
+33363965326261303234626463623963633531343539616138316433353830356566396130353436
+3562643163366231316662386565383735653432386435610a306664636137376132643732393835
+63383038383730306639353234326630666539346233376330303938323639306661313032396437
+6233623062366136310a633866373936313238333730653739323461656662303864663666653563
+3138'''
+
+ b_vaulttext_envelope = to_bytes(vaulttext_envelope, errors='strict', encoding='utf-8')
+ b_vaulttext, b_version, cipher_name, vault_id = vault.parse_vaulttext_envelope(b_vaulttext_envelope)
+ res = vault.parse_vaulttext(b_vaulttext)
+ self.assertIsInstance(res[0], bytes)
+ self.assertIsInstance(res[1], bytes)
+ self.assertIsInstance(res[2], bytes)
+
+ def test_non_hex(self):
+ vaulttext_envelope = u'''$ANSIBLE_VAULT;1.1;AES256
+3336396J326261303234626463623963633531343539616138316433353830356566396130353436
+3562643163366231316662386565383735653432386435610a306664636137376132643732393835
+63383038383730306639353234326630666539346233376330303938323639306661313032396437
+6233623062366136310a633866373936313238333730653739323461656662303864663666653563
+3138'''
+
+ b_vaulttext_envelope = to_bytes(vaulttext_envelope, errors='strict', encoding='utf-8')
+ b_vaulttext, b_version, cipher_name, vault_id = vault.parse_vaulttext_envelope(b_vaulttext_envelope)
+ self.assertRaisesRegex(vault.AnsibleVaultFormatError,
+ '.*Vault format unhexlify error.*Non-hexadecimal digit found',
+ vault.parse_vaulttext,
+ b_vaulttext_envelope)
+
+
+class TestVaultSecret(unittest.TestCase):
+ def test(self):
+ secret = vault.VaultSecret()
+ secret.load()
+ self.assertIsNone(secret._bytes)
+
+ def test_bytes(self):
+ some_text = u'私はガラスを食べられます。それは私を傷つけません。'
+ _bytes = to_bytes(some_text)
+ secret = vault.VaultSecret(_bytes)
+ secret.load()
+ self.assertEqual(secret.bytes, _bytes)
+
+
+class TestPromptVaultSecret(unittest.TestCase):
+ def test_empty_prompt_formats(self):
+ secret = vault.PromptVaultSecret(vault_id='test_id', prompt_formats=[])
+ secret.load()
+ self.assertIsNone(secret._bytes)
+
+ @patch('ansible.parsing.vault.display.prompt', return_value='the_password')
+ def test_prompt_formats_none(self, mock_display_prompt):
+ secret = vault.PromptVaultSecret(vault_id='test_id')
+ secret.load()
+ self.assertEqual(secret._bytes, b'the_password')
+
+ @patch('ansible.parsing.vault.display.prompt', return_value='the_password')
+ def test_custom_prompt(self, mock_display_prompt):
+ secret = vault.PromptVaultSecret(vault_id='test_id',
+ prompt_formats=['The cow flies at midnight: '])
+ secret.load()
+ self.assertEqual(secret._bytes, b'the_password')
+
+ @patch('ansible.parsing.vault.display.prompt', side_effect=EOFError)
+ def test_prompt_eoferror(self, mock_display_prompt):
+ secret = vault.PromptVaultSecret(vault_id='test_id')
+ self.assertRaisesRegex(vault.AnsibleVaultError,
+ 'EOFError.*test_id',
+ secret.load)
+
+ @patch('ansible.parsing.vault.display.prompt', side_effect=['first_password', 'second_password'])
+ def test_prompt_passwords_dont_match(self, mock_display_prompt):
+ secret = vault.PromptVaultSecret(vault_id='test_id',
+ prompt_formats=['Vault password: ',
+ 'Confirm Vault password: '])
+ self.assertRaisesRegex(errors.AnsibleError,
+ 'Passwords do not match',
+ secret.load)
+
+
+class TestFileVaultSecret(unittest.TestCase):
+ def setUp(self):
+ self.vault_password = "test-vault-password"
+ text_secret = TextVaultSecret(self.vault_password)
+ self.vault_secrets = [('foo', text_secret)]
+
+ def test(self):
+ secret = vault.FileVaultSecret()
+ self.assertIsNone(secret._bytes)
+ self.assertIsNone(secret._text)
+
+ def test_repr_empty(self):
+ secret = vault.FileVaultSecret()
+ self.assertEqual(repr(secret), "FileVaultSecret()")
+
+ def test_repr(self):
+ tmp_file = tempfile.NamedTemporaryFile(delete=False)
+ fake_loader = DictDataLoader({tmp_file.name: 'sdfadf'})
+
+ secret = vault.FileVaultSecret(loader=fake_loader, filename=tmp_file.name)
+ filename = tmp_file.name
+ tmp_file.close()
+ self.assertEqual(repr(secret), "FileVaultSecret(filename='%s')" % filename)
+
+ def test_empty_bytes(self):
+ secret = vault.FileVaultSecret()
+ self.assertIsNone(secret.bytes)
+
+ def test_file(self):
+ password = 'some password'
+
+ tmp_file = tempfile.NamedTemporaryFile(delete=False)
+ tmp_file.write(to_bytes(password))
+ tmp_file.close()
+
+ fake_loader = DictDataLoader({tmp_file.name: 'sdfadf'})
+
+ secret = vault.FileVaultSecret(loader=fake_loader, filename=tmp_file.name)
+ secret.load()
+
+ os.unlink(tmp_file.name)
+
+ self.assertEqual(secret.bytes, to_bytes(password))
+
+ def test_file_empty(self):
+
+ tmp_file = tempfile.NamedTemporaryFile(delete=False)
+ tmp_file.write(to_bytes(''))
+ tmp_file.close()
+
+ fake_loader = DictDataLoader({tmp_file.name: ''})
+
+ secret = vault.FileVaultSecret(loader=fake_loader, filename=tmp_file.name)
+ self.assertRaisesRegex(vault.AnsibleVaultPasswordError,
+ 'Invalid vault password was provided from file.*%s' % tmp_file.name,
+ secret.load)
+
+ os.unlink(tmp_file.name)
+
+ def test_file_encrypted(self):
+ vault_password = "test-vault-password"
+ text_secret = TextVaultSecret(vault_password)
+ vault_secrets = [('foo', text_secret)]
+
+ password = 'some password'
+ # 'some password' encrypted with 'test-ansible-password'
+
+ password_file_content = '''$ANSIBLE_VAULT;1.1;AES256
+61393863643638653437313566313632306462383837303132346434616433313438353634613762
+3334363431623364386164616163326537366333353663650a663634306232363432626162353665
+39623061353266373631636331643761306665343731376633623439313138396330346237653930
+6432643864346136640a653364386634666461306231353765636662316335613235383565306437
+3737
+'''
+
+ tmp_file = tempfile.NamedTemporaryFile(delete=False)
+ tmp_file.write(to_bytes(password_file_content))
+ tmp_file.close()
+
+ fake_loader = DictDataLoader({tmp_file.name: 'sdfadf'})
+ fake_loader._vault.secrets = vault_secrets
+
+ secret = vault.FileVaultSecret(loader=fake_loader, filename=tmp_file.name)
+ secret.load()
+
+ os.unlink(tmp_file.name)
+
+ self.assertEqual(secret.bytes, to_bytes(password))
+
+ def test_file_not_a_directory(self):
+ filename = '/dev/null/foobar'
+ fake_loader = DictDataLoader({filename: 'sdfadf'})
+
+ secret = vault.FileVaultSecret(loader=fake_loader, filename=filename)
+ self.assertRaisesRegex(errors.AnsibleError,
+ '.*Could not read vault password file.*/dev/null/foobar.*Not a directory',
+ secret.load)
+
+ def test_file_not_found(self):
+ tmp_file = tempfile.NamedTemporaryFile()
+ filename = os.path.realpath(tmp_file.name)
+ tmp_file.close()
+
+ fake_loader = DictDataLoader({filename: 'sdfadf'})
+
+ secret = vault.FileVaultSecret(loader=fake_loader, filename=filename)
+ self.assertRaisesRegex(errors.AnsibleError,
+ '.*Could not read vault password file.*%s.*' % filename,
+ secret.load)
+
+
+class TestScriptVaultSecret(unittest.TestCase):
+ def test(self):
+ secret = vault.ScriptVaultSecret()
+ self.assertIsNone(secret._bytes)
+ self.assertIsNone(secret._text)
+
+ def _mock_popen(self, mock_popen, return_code=0, stdout=b'', stderr=b''):
+ def communicate():
+ return stdout, stderr
+ mock_popen.return_value = MagicMock(returncode=return_code)
+ mock_popen_instance = mock_popen.return_value
+ mock_popen_instance.communicate = communicate
+
+ @patch('ansible.parsing.vault.subprocess.Popen')
+ def test_read_file(self, mock_popen):
+ self._mock_popen(mock_popen, stdout=b'some_password')
+ secret = vault.ScriptVaultSecret()
+ with patch.object(secret, 'loader') as mock_loader:
+ mock_loader.is_executable = MagicMock(return_value=True)
+ secret.load()
+
+ @patch('ansible.parsing.vault.subprocess.Popen')
+ def test_read_file_empty(self, mock_popen):
+ self._mock_popen(mock_popen, stdout=b'')
+ secret = vault.ScriptVaultSecret()
+ with patch.object(secret, 'loader') as mock_loader:
+ mock_loader.is_executable = MagicMock(return_value=True)
+ self.assertRaisesRegex(vault.AnsibleVaultPasswordError,
+ 'Invalid vault password was provided from script',
+ secret.load)
+
+ @patch('ansible.parsing.vault.subprocess.Popen')
+ def test_read_file_os_error(self, mock_popen):
+ self._mock_popen(mock_popen)
+ mock_popen.side_effect = OSError('That is not an executable')
+ secret = vault.ScriptVaultSecret()
+ with patch.object(secret, 'loader') as mock_loader:
+ mock_loader.is_executable = MagicMock(return_value=True)
+ self.assertRaisesRegex(errors.AnsibleError,
+ 'Problem running vault password script.*',
+ secret.load)
+
+ @patch('ansible.parsing.vault.subprocess.Popen')
+ def test_read_file_not_executable(self, mock_popen):
+ self._mock_popen(mock_popen)
+ secret = vault.ScriptVaultSecret()
+ with patch.object(secret, 'loader') as mock_loader:
+ mock_loader.is_executable = MagicMock(return_value=False)
+ self.assertRaisesRegex(vault.AnsibleVaultError,
+ 'The vault password script .* was not executable',
+ secret.load)
+
+ @patch('ansible.parsing.vault.subprocess.Popen')
+ def test_read_file_non_zero_return_code(self, mock_popen):
+ stderr = b'That did not work for a random reason'
+ rc = 37
+
+ self._mock_popen(mock_popen, return_code=rc, stderr=stderr)
+ secret = vault.ScriptVaultSecret(filename='/dev/null/some_vault_secret')
+ with patch.object(secret, 'loader') as mock_loader:
+ mock_loader.is_executable = MagicMock(return_value=True)
+ self.assertRaisesRegex(errors.AnsibleError,
+ r'Vault password script.*returned non-zero \(%s\): %s' % (rc, stderr),
+ secret.load)
+
+
+class TestScriptIsClient(unittest.TestCase):
+ def test_randomname(self):
+ filename = 'randomname'
+ res = vault.script_is_client(filename)
+ self.assertFalse(res)
+
+ def test_something_dash_client(self):
+ filename = 'something-client'
+ res = vault.script_is_client(filename)
+ self.assertTrue(res)
+
+ def test_something_dash_client_somethingelse(self):
+ filename = 'something-client-somethingelse'
+ res = vault.script_is_client(filename)
+ self.assertFalse(res)
+
+ def test_something_dash_client_py(self):
+ filename = 'something-client.py'
+ res = vault.script_is_client(filename)
+ self.assertTrue(res)
+
+ def test_full_path_something_dash_client_py(self):
+ filename = '/foo/bar/something-client.py'
+ res = vault.script_is_client(filename)
+ self.assertTrue(res)
+
+ def test_full_path_something_dash_client(self):
+ filename = '/foo/bar/something-client'
+ res = vault.script_is_client(filename)
+ self.assertTrue(res)
+
+ def test_full_path_something_dash_client_in_dir(self):
+ filename = '/foo/bar/something-client/but/not/filename'
+ res = vault.script_is_client(filename)
+ self.assertFalse(res)
+
+
+class TestGetFileVaultSecret(unittest.TestCase):
+ def test_file(self):
+ password = 'some password'
+
+ tmp_file = tempfile.NamedTemporaryFile(delete=False)
+ tmp_file.write(to_bytes(password))
+ tmp_file.close()
+
+ fake_loader = DictDataLoader({tmp_file.name: 'sdfadf'})
+
+ secret = vault.get_file_vault_secret(filename=tmp_file.name, loader=fake_loader)
+ secret.load()
+
+ os.unlink(tmp_file.name)
+
+ self.assertEqual(secret.bytes, to_bytes(password))
+
+ def test_file_not_a_directory(self):
+ filename = '/dev/null/foobar'
+ fake_loader = DictDataLoader({filename: 'sdfadf'})
+
+ self.assertRaisesRegex(errors.AnsibleError,
+ '.*The vault password file %s was not found.*' % filename,
+ vault.get_file_vault_secret,
+ filename=filename,
+ loader=fake_loader)
+
+ def test_file_not_found(self):
+ tmp_file = tempfile.NamedTemporaryFile()
+ filename = os.path.realpath(tmp_file.name)
+ tmp_file.close()
+
+ fake_loader = DictDataLoader({filename: 'sdfadf'})
+
+ self.assertRaisesRegex(errors.AnsibleError,
+ '.*The vault password file %s was not found.*' % filename,
+ vault.get_file_vault_secret,
+ filename=filename,
+ loader=fake_loader)
+
+
+class TestVaultIsEncrypted(unittest.TestCase):
+ def test_bytes_not_encrypted(self):
+ b_data = b"foobar"
+ self.assertFalse(vault.is_encrypted(b_data))
+
+ def test_bytes_encrypted(self):
+ b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible")
+ self.assertTrue(vault.is_encrypted(b_data))
+
+ def test_text_not_encrypted(self):
+ b_data = to_text(b"foobar")
+ self.assertFalse(vault.is_encrypted(b_data))
+
+ def test_text_encrypted(self):
+ b_data = to_text(b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible"))
+ self.assertTrue(vault.is_encrypted(b_data))
+
+ def test_invalid_text_not_ascii(self):
+ data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % u"ァ ア ィ イ ゥ ウ ェ エ ォ オ カ ガ キ ギ ク グ ケ "
+ self.assertFalse(vault.is_encrypted(data))
+
+ def test_invalid_bytes_not_ascii(self):
+ data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % u"ァ ア ィ イ ゥ ウ ェ エ ォ オ カ ガ キ ギ ク グ ケ "
+ b_data = to_bytes(data, encoding='utf-8')
+ self.assertFalse(vault.is_encrypted(b_data))
+
+
+class TestVaultIsEncryptedFile(unittest.TestCase):
+ def test_binary_file_handle_not_encrypted(self):
+ b_data = b"foobar"
+ b_data_fo = io.BytesIO(b_data)
+ self.assertFalse(vault.is_encrypted_file(b_data_fo))
+
+ def test_text_file_handle_not_encrypted(self):
+ data = u"foobar"
+ data_fo = io.StringIO(data)
+ self.assertFalse(vault.is_encrypted_file(data_fo))
+
+ def test_binary_file_handle_encrypted(self):
+ b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible")
+ b_data_fo = io.BytesIO(b_data)
+ self.assertTrue(vault.is_encrypted_file(b_data_fo))
+
+ def test_text_file_handle_encrypted(self):
+ data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % to_text(hexlify(b"ansible"))
+ data_fo = io.StringIO(data)
+ self.assertTrue(vault.is_encrypted_file(data_fo))
+
+ def test_binary_file_handle_invalid(self):
+ data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % u"ァ ア ィ イ ゥ ウ ェ エ ォ オ カ ガ キ ギ ク グ ケ "
+ b_data = to_bytes(data)
+ b_data_fo = io.BytesIO(b_data)
+ self.assertFalse(vault.is_encrypted_file(b_data_fo))
+
+ def test_text_file_handle_invalid(self):
+ data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % u"ァ ア ィ イ ゥ ウ ェ エ ォ オ カ ガ キ ギ ク グ ケ "
+ data_fo = io.StringIO(data)
+ self.assertFalse(vault.is_encrypted_file(data_fo))
+
+ def test_file_already_read_from_finds_header(self):
+ b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible\ntesting\nfile pos")
+ b_data_fo = io.BytesIO(b_data)
+ b_data_fo.read(42) # Arbitrary number
+ self.assertTrue(vault.is_encrypted_file(b_data_fo))
+
+ def test_file_already_read_from_saves_file_pos(self):
+ b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible\ntesting\nfile pos")
+ b_data_fo = io.BytesIO(b_data)
+ b_data_fo.read(69) # Arbitrary number
+ vault.is_encrypted_file(b_data_fo)
+ self.assertEqual(b_data_fo.tell(), 69)
+
+ def test_file_with_offset(self):
+ b_data = b"JUNK$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible\ntesting\nfile pos")
+ b_data_fo = io.BytesIO(b_data)
+ self.assertTrue(vault.is_encrypted_file(b_data_fo, start_pos=4))
+
+ def test_file_with_count(self):
+ b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible\ntesting\nfile pos")
+ vault_length = len(b_data)
+ b_data = b_data + u'ァ ア'.encode('utf-8')
+ b_data_fo = io.BytesIO(b_data)
+ self.assertTrue(vault.is_encrypted_file(b_data_fo, count=vault_length))
+
+ def test_file_with_offset_and_count(self):
+ b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible\ntesting\nfile pos")
+ vault_length = len(b_data)
+ b_data = b'JUNK' + b_data + u'ァ ア'.encode('utf-8')
+ b_data_fo = io.BytesIO(b_data)
+ self.assertTrue(vault.is_encrypted_file(b_data_fo, start_pos=4, count=vault_length))
+
+
+@pytest.mark.skipif(not vault.HAS_CRYPTOGRAPHY,
+ reason="Skipping cryptography tests because cryptography is not installed")
+class TestVaultCipherAes256(unittest.TestCase):
+ def setUp(self):
+ self.vault_cipher = vault.VaultAES256()
+
+ def test(self):
+ self.assertIsInstance(self.vault_cipher, vault.VaultAES256)
+
+ # TODO: tag these as slow tests
+ def test_create_key_cryptography(self):
+ b_password = b'hunter42'
+ b_salt = os.urandom(32)
+ b_key_cryptography = self.vault_cipher._create_key_cryptography(b_password, b_salt, key_length=32, iv_length=16)
+ self.assertIsInstance(b_key_cryptography, six.binary_type)
+
+ def test_create_key_known_cryptography(self):
+ b_password = b'hunter42'
+
+ # A fixed salt
+ b_salt = b'q' * 32 # q is the most random letter.
+ b_key_1 = self.vault_cipher._create_key_cryptography(b_password, b_salt, key_length=32, iv_length=16)
+ self.assertIsInstance(b_key_1, six.binary_type)
+
+ # verify we get the same answer
+ # we could potentially run a few iterations of this and time it to see if it's roughly constant time
+ # and or that it exceeds some minimal time, but that would likely cause unreliable fails, esp in CI
+ b_key_2 = self.vault_cipher._create_key_cryptography(b_password, b_salt, key_length=32, iv_length=16)
+ self.assertIsInstance(b_key_2, six.binary_type)
+ self.assertEqual(b_key_1, b_key_2)
+
+ def test_is_equal_is_equal(self):
+ self.assertTrue(self.vault_cipher._is_equal(b'abcdefghijklmnopqrstuvwxyz', b'abcdefghijklmnopqrstuvwxyz'))
+
+ def test_is_equal_unequal_length(self):
+ self.assertFalse(self.vault_cipher._is_equal(b'abcdefghijklmnopqrstuvwxyz', b'abcdefghijklmnopqrstuvwx and sometimes y'))
+
+ def test_is_equal_not_equal(self):
+ self.assertFalse(self.vault_cipher._is_equal(b'abcdefghijklmnopqrstuvwxyz', b'AbcdefghijKlmnopQrstuvwxZ'))
+
+ def test_is_equal_empty(self):
+ self.assertTrue(self.vault_cipher._is_equal(b'', b''))
+
+ def test_is_equal_non_ascii_equal(self):
+ utf8_data = to_bytes(u'私はガラスを食べられます。それは私を傷つけません。')
+ self.assertTrue(self.vault_cipher._is_equal(utf8_data, utf8_data))
+
+ def test_is_equal_non_ascii_unequal(self):
+ utf8_data = to_bytes(u'私はガラスを食べられます。それは私を傷つけません。')
+ utf8_data2 = to_bytes(u'Pot să mănânc sticlă și ea nu mă rănește.')
+
+ # Test for the len optimization path
+ self.assertFalse(self.vault_cipher._is_equal(utf8_data, utf8_data2))
+ # Test for the slower, char by char comparison path
+ self.assertFalse(self.vault_cipher._is_equal(utf8_data, utf8_data[:-1] + b'P'))
+
+ def test_is_equal_non_bytes(self):
+ """ Anything not a byte string should raise a TypeError """
+ self.assertRaises(TypeError, self.vault_cipher._is_equal, u"One fish", b"two fish")
+ self.assertRaises(TypeError, self.vault_cipher._is_equal, b"One fish", u"two fish")
+ self.assertRaises(TypeError, self.vault_cipher._is_equal, 1, b"red fish")
+ self.assertRaises(TypeError, self.vault_cipher._is_equal, b"blue fish", 2)
+
+
+class TestMatchSecrets(unittest.TestCase):
+ def test_empty_tuple(self):
+ secrets = [tuple()]
+ vault_ids = ['vault_id_1']
+ self.assertRaises(ValueError,
+ vault.match_secrets,
+ secrets, vault_ids)
+
+ def test_empty_secrets(self):
+ matches = vault.match_secrets([], ['vault_id_1'])
+ self.assertEqual(matches, [])
+
+ def test_single_match(self):
+ secret = TextVaultSecret('password')
+ matches = vault.match_secrets([('default', secret)], ['default'])
+ self.assertEqual(matches, [('default', secret)])
+
+ def test_no_matches(self):
+ secret = TextVaultSecret('password')
+ matches = vault.match_secrets([('default', secret)], ['not_default'])
+ self.assertEqual(matches, [])
+
+ def test_multiple_matches(self):
+ secrets = [('vault_id1', TextVaultSecret('password1')),
+ ('vault_id2', TextVaultSecret('password2')),
+ ('vault_id1', TextVaultSecret('password3')),
+ ('vault_id4', TextVaultSecret('password4'))]
+ vault_ids = ['vault_id1', 'vault_id4']
+ matches = vault.match_secrets(secrets, vault_ids)
+
+ self.assertEqual(len(matches), 3)
+ expected = [('vault_id1', TextVaultSecret('password1')),
+ ('vault_id1', TextVaultSecret('password3')),
+ ('vault_id4', TextVaultSecret('password4'))]
+ self.assertEqual([x for x, y in matches],
+ [a for a, b in expected])
+
+
+@pytest.mark.skipif(not vault.HAS_CRYPTOGRAPHY,
+ reason="Skipping cryptography tests because cryptography is not installed")
+class TestVaultLib(unittest.TestCase):
+ def setUp(self):
+ self.vault_password = "test-vault-password"
+ text_secret = TextVaultSecret(self.vault_password)
+ self.vault_secrets = [('default', text_secret),
+ ('test_id', text_secret)]
+ self.v = vault.VaultLib(self.vault_secrets)
+
+ def _vault_secrets(self, vault_id, secret):
+ return [(vault_id, secret)]
+
+ def _vault_secrets_from_password(self, vault_id, password):
+ return [(vault_id, TextVaultSecret(password))]
+
+ def test_encrypt(self):
+ plaintext = u'Some text to encrypt in a café'
+ b_vaulttext = self.v.encrypt(plaintext)
+
+ self.assertIsInstance(b_vaulttext, six.binary_type)
+
+ b_header = b'$ANSIBLE_VAULT;1.1;AES256\n'
+ self.assertEqual(b_vaulttext[:len(b_header)], b_header)
+
+ def test_encrypt_vault_id(self):
+ plaintext = u'Some text to encrypt in a café'
+ b_vaulttext = self.v.encrypt(plaintext, vault_id='test_id')
+
+ self.assertIsInstance(b_vaulttext, six.binary_type)
+
+ b_header = b'$ANSIBLE_VAULT;1.2;AES256;test_id\n'
+ self.assertEqual(b_vaulttext[:len(b_header)], b_header)
+
+ def test_encrypt_bytes(self):
+
+ plaintext = to_bytes(u'Some text to encrypt in a café')
+ b_vaulttext = self.v.encrypt(plaintext)
+
+ self.assertIsInstance(b_vaulttext, six.binary_type)
+
+ b_header = b'$ANSIBLE_VAULT;1.1;AES256\n'
+ self.assertEqual(b_vaulttext[:len(b_header)], b_header)
+
+ def test_encrypt_no_secret_empty_secrets(self):
+ vault_secrets = []
+ v = vault.VaultLib(vault_secrets)
+
+ plaintext = u'Some text to encrypt in a café'
+ self.assertRaisesRegex(vault.AnsibleVaultError,
+ '.*A vault password must be specified to encrypt data.*',
+ v.encrypt,
+ plaintext)
+
+ def test_format_vaulttext_envelope(self):
+ cipher_name = "TEST"
+ b_ciphertext = b"ansible"
+ b_vaulttext = vault.format_vaulttext_envelope(b_ciphertext,
+ cipher_name,
+ version=self.v.b_version,
+ vault_id='default')
+ b_lines = b_vaulttext.split(b'\n')
+ self.assertGreater(len(b_lines), 1, msg="failed to properly add header")
+
+ b_header = b_lines[0]
+ # self.assertTrue(b_header.endswith(b';TEST'), msg="header does not end with cipher name")
+
+ b_header_parts = b_header.split(b';')
+ self.assertEqual(len(b_header_parts), 4, msg="header has the wrong number of parts")
+ self.assertEqual(b_header_parts[0], b'$ANSIBLE_VAULT', msg="header does not start with $ANSIBLE_VAULT")
+ self.assertEqual(b_header_parts[1], self.v.b_version, msg="header version is incorrect")
+ self.assertEqual(b_header_parts[2], b'TEST', msg="header does not end with cipher name")
+
+ # And just to verify, lets parse the results and compare
+ b_ciphertext2, b_version2, cipher_name2, vault_id2 = \
+ vault.parse_vaulttext_envelope(b_vaulttext)
+ self.assertEqual(b_ciphertext, b_ciphertext2)
+ self.assertEqual(self.v.b_version, b_version2)
+ self.assertEqual(cipher_name, cipher_name2)
+ self.assertEqual('default', vault_id2)
+
+ def test_parse_vaulttext_envelope(self):
+ b_vaulttext = b"$ANSIBLE_VAULT;9.9;TEST\nansible"
+ b_ciphertext, b_version, cipher_name, vault_id = vault.parse_vaulttext_envelope(b_vaulttext)
+ b_lines = b_ciphertext.split(b'\n')
+ self.assertEqual(b_lines[0], b"ansible", msg="Payload was not properly split from the header")
+ self.assertEqual(cipher_name, u'TEST', msg="cipher name was not properly set")
+ self.assertEqual(b_version, b"9.9", msg="version was not properly set")
+
+ def test_parse_vaulttext_envelope_crlf(self):
+ b_vaulttext = b"$ANSIBLE_VAULT;9.9;TEST\r\nansible"
+ b_ciphertext, b_version, cipher_name, vault_id = vault.parse_vaulttext_envelope(b_vaulttext)
+ b_lines = b_ciphertext.split(b'\n')
+ self.assertEqual(b_lines[0], b"ansible", msg="Payload was not properly split from the header")
+ self.assertEqual(cipher_name, u'TEST', msg="cipher name was not properly set")
+ self.assertEqual(b_version, b"9.9", msg="version was not properly set")
+
+ def test_encrypt_decrypt_aes256(self):
+ self.v.cipher_name = u'AES256'
+ plaintext = u"foobar"
+ b_vaulttext = self.v.encrypt(plaintext)
+ b_plaintext = self.v.decrypt(b_vaulttext)
+ self.assertNotEqual(b_vaulttext, b"foobar", msg="encryption failed")
+ self.assertEqual(b_plaintext, b"foobar", msg="decryption failed")
+
+ def test_encrypt_decrypt_aes256_none_secrets(self):
+ vault_secrets = self._vault_secrets_from_password('default', 'ansible')
+ v = vault.VaultLib(vault_secrets)
+
+ plaintext = u"foobar"
+ b_vaulttext = v.encrypt(plaintext)
+
+ # VaultLib will default to empty {} if secrets is None
+ v_none = vault.VaultLib(None)
+ # so set secrets None explicitly
+ v_none.secrets = None
+ self.assertRaisesRegex(vault.AnsibleVaultError,
+ '.*A vault password must be specified to decrypt data.*',
+ v_none.decrypt,
+ b_vaulttext)
+
+ def test_encrypt_decrypt_aes256_empty_secrets(self):
+ vault_secrets = self._vault_secrets_from_password('default', 'ansible')
+ v = vault.VaultLib(vault_secrets)
+
+ plaintext = u"foobar"
+ b_vaulttext = v.encrypt(plaintext)
+
+ vault_secrets_empty = []
+ v_none = vault.VaultLib(vault_secrets_empty)
+
+ self.assertRaisesRegex(vault.AnsibleVaultError,
+ '.*Attempting to decrypt but no vault secrets found.*',
+ v_none.decrypt,
+ b_vaulttext)
+
+ def test_encrypt_decrypt_aes256_multiple_secrets_all_wrong(self):
+ plaintext = u'Some text to encrypt in a café'
+ b_vaulttext = self.v.encrypt(plaintext)
+
+ vault_secrets = [('default', TextVaultSecret('another-wrong-password')),
+ ('wrong-password', TextVaultSecret('wrong-password'))]
+
+ v_multi = vault.VaultLib(vault_secrets)
+ self.assertRaisesRegex(errors.AnsibleError,
+ '.*Decryption failed.*',
+ v_multi.decrypt,
+ b_vaulttext,
+ filename='/dev/null/fake/filename')
+
+ def test_encrypt_decrypt_aes256_multiple_secrets_one_valid(self):
+ plaintext = u'Some text to encrypt in a café'
+ b_vaulttext = self.v.encrypt(plaintext)
+
+ correct_secret = TextVaultSecret(self.vault_password)
+ wrong_secret = TextVaultSecret('wrong-password')
+
+ vault_secrets = [('default', wrong_secret),
+ ('corect_secret', correct_secret),
+ ('wrong_secret', wrong_secret)]
+
+ v_multi = vault.VaultLib(vault_secrets)
+ b_plaintext = v_multi.decrypt(b_vaulttext)
+ self.assertNotEqual(b_vaulttext, to_bytes(plaintext), msg="encryption failed")
+ self.assertEqual(b_plaintext, to_bytes(plaintext), msg="decryption failed")
+
+ def test_encrypt_decrypt_aes256_existing_vault(self):
+ self.v.cipher_name = u'AES256'
+ b_orig_plaintext = b"Setec Astronomy"
+ vaulttext = u'''$ANSIBLE_VAULT;1.1;AES256
+33363965326261303234626463623963633531343539616138316433353830356566396130353436
+3562643163366231316662386565383735653432386435610a306664636137376132643732393835
+63383038383730306639353234326630666539346233376330303938323639306661313032396437
+6233623062366136310a633866373936313238333730653739323461656662303864663666653563
+3138'''
+
+ b_plaintext = self.v.decrypt(vaulttext)
+ self.assertEqual(b_plaintext, b_plaintext, msg="decryption failed")
+
+ b_vaulttext = to_bytes(vaulttext, encoding='ascii', errors='strict')
+ b_plaintext = self.v.decrypt(b_vaulttext)
+ self.assertEqual(b_plaintext, b_orig_plaintext, msg="decryption failed")
+
+ # FIXME This test isn't working quite yet.
+ @pytest.mark.skip(reason='This test is not ready yet')
+ def test_encrypt_decrypt_aes256_bad_hmac(self):
+
+ self.v.cipher_name = 'AES256'
+ # plaintext = "Setec Astronomy"
+ enc_data = '''$ANSIBLE_VAULT;1.1;AES256
+33363965326261303234626463623963633531343539616138316433353830356566396130353436
+3562643163366231316662386565383735653432386435610a306664636137376132643732393835
+63383038383730306639353234326630666539346233376330303938323639306661313032396437
+6233623062366136310a633866373936313238333730653739323461656662303864663666653563
+3138'''
+ b_data = to_bytes(enc_data, errors='strict', encoding='utf-8')
+ b_data = self.v._split_header(b_data)
+ foo = binascii.unhexlify(b_data)
+ lines = foo.splitlines()
+ # line 0 is salt, line 1 is hmac, line 2+ is ciphertext
+ b_salt = lines[0]
+ b_hmac = lines[1]
+ b_ciphertext_data = b'\n'.join(lines[2:])
+
+ b_ciphertext = binascii.unhexlify(b_ciphertext_data)
+ # b_orig_ciphertext = b_ciphertext[:]
+
+ # now muck with the text
+ # b_munged_ciphertext = b_ciphertext[:10] + b'\x00' + b_ciphertext[11:]
+ # b_munged_ciphertext = b_ciphertext
+ # assert b_orig_ciphertext != b_munged_ciphertext
+
+ b_ciphertext_data = binascii.hexlify(b_ciphertext)
+ b_payload = b'\n'.join([b_salt, b_hmac, b_ciphertext_data])
+ # reformat
+ b_invalid_ciphertext = self.v._format_output(b_payload)
+
+ # assert we throw an error
+ self.v.decrypt(b_invalid_ciphertext)
+
+ def test_decrypt_and_get_vault_id(self):
+ b_expected_plaintext = to_bytes('foo bar\n')
+ vaulttext = '''$ANSIBLE_VAULT;1.2;AES256;ansible_devel
+65616435333934613466373335363332373764363365633035303466643439313864663837393234
+3330656363343637313962633731333237313636633534630a386264363438363362326132363239
+39363166646664346264383934393935653933316263333838386362633534326664646166663736
+6462303664383765650a356637643633366663643566353036303162386237336233393065393164
+6264'''
+
+ vault_secrets = self._vault_secrets_from_password('ansible_devel', 'ansible')
+ v = vault.VaultLib(vault_secrets)
+
+ b_vaulttext = to_bytes(vaulttext)
+
+ b_plaintext, vault_id_used, vault_secret_used = v.decrypt_and_get_vault_id(b_vaulttext)
+
+ self.assertEqual(b_expected_plaintext, b_plaintext)
+ self.assertEqual(vault_id_used, 'ansible_devel')
+ self.assertEqual(vault_secret_used.text, 'ansible')
+
+ def test_decrypt_non_default_1_2(self):
+ b_expected_plaintext = to_bytes('foo bar\n')
+ vaulttext = '''$ANSIBLE_VAULT;1.2;AES256;ansible_devel
+65616435333934613466373335363332373764363365633035303466643439313864663837393234
+3330656363343637313962633731333237313636633534630a386264363438363362326132363239
+39363166646664346264383934393935653933316263333838386362633534326664646166663736
+6462303664383765650a356637643633366663643566353036303162386237336233393065393164
+6264'''
+
+ vault_secrets = self._vault_secrets_from_password('default', 'ansible')
+ v = vault.VaultLib(vault_secrets)
+
+ b_vaulttext = to_bytes(vaulttext)
+
+ b_plaintext = v.decrypt(b_vaulttext)
+ self.assertEqual(b_expected_plaintext, b_plaintext)
+
+ b_ciphertext, b_version, cipher_name, vault_id = vault.parse_vaulttext_envelope(b_vaulttext)
+ self.assertEqual('ansible_devel', vault_id)
+ self.assertEqual(b'1.2', b_version)
+
+ def test_decrypt_decrypted(self):
+ plaintext = u"ansible"
+ self.assertRaises(errors.AnsibleError, self.v.decrypt, plaintext)
+
+ b_plaintext = b"ansible"
+ self.assertRaises(errors.AnsibleError, self.v.decrypt, b_plaintext)
+
+ def test_cipher_not_set(self):
+ plaintext = u"ansible"
+ self.v.encrypt(plaintext)
+ self.assertEqual(self.v.cipher_name, "AES256")
diff --git a/test/units/parsing/vault/test_vault_editor.py b/test/units/parsing/vault/test_vault_editor.py
new file mode 100644
index 0000000..77509f0
--- /dev/null
+++ b/test/units/parsing/vault/test_vault_editor.py
@@ -0,0 +1,521 @@
+# (c) 2014, James Tanner <tanner.jc@gmail.com>
+# (c) 2014, James Cammarata, <jcammarata@ansible.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import os
+import tempfile
+from io import BytesIO, StringIO
+
+import pytest
+
+from units.compat import unittest
+from unittest.mock import patch
+
+from ansible import errors
+from ansible.parsing import vault
+from ansible.parsing.vault import VaultLib, VaultEditor, match_encrypt_secret
+
+from ansible.module_utils.six import PY3
+from ansible.module_utils._text import to_bytes, to_text
+
+from units.mock.vault_helper import TextVaultSecret
+
+v11_data = """$ANSIBLE_VAULT;1.1;AES256
+62303130653266653331306264616235333735323636616539316433666463323964623162386137
+3961616263373033353631316333623566303532663065310a393036623466376263393961326530
+64336561613965383835646464623865663966323464653236343638373165343863623638316664
+3631633031323837340a396530313963373030343933616133393566366137363761373930663833
+3739"""
+
+
+@pytest.mark.skipif(not vault.HAS_CRYPTOGRAPHY,
+ reason="Skipping cryptography tests because cryptography is not installed")
+class TestVaultEditor(unittest.TestCase):
+
+ def setUp(self):
+ self._test_dir = None
+ self.vault_password = "test-vault-password"
+ vault_secret = TextVaultSecret(self.vault_password)
+ self.vault_secrets = [('vault_secret', vault_secret),
+ ('default', vault_secret)]
+
+ @property
+ def vault_secret(self):
+ return match_encrypt_secret(self.vault_secrets)[1]
+
+ def tearDown(self):
+ if self._test_dir:
+ pass
+ # shutil.rmtree(self._test_dir)
+ self._test_dir = None
+
+ def _secrets(self, password):
+ vault_secret = TextVaultSecret(password)
+ vault_secrets = [('default', vault_secret)]
+ return vault_secrets
+
+ def test_methods_exist(self):
+ v = vault.VaultEditor(None)
+ slots = ['create_file',
+ 'decrypt_file',
+ 'edit_file',
+ 'encrypt_file',
+ 'rekey_file',
+ 'read_data',
+ 'write_data']
+ for slot in slots:
+ assert hasattr(v, slot), "VaultLib is missing the %s method" % slot
+
+ def _create_test_dir(self):
+ suffix = '_ansible_unit_test_%s_' % (self.__class__.__name__)
+ return tempfile.mkdtemp(suffix=suffix)
+
+ def _create_file(self, test_dir, name, content=None, symlink=False):
+ file_path = os.path.join(test_dir, name)
+ opened_file = open(file_path, 'wb')
+ if content:
+ opened_file.write(content)
+ opened_file.close()
+ return file_path
+
+ def _vault_editor(self, vault_secrets=None):
+ if vault_secrets is None:
+ vault_secrets = self._secrets(self.vault_password)
+ return VaultEditor(VaultLib(vault_secrets))
+
+ @patch('ansible.parsing.vault.subprocess.call')
+ def test_edit_file_helper_empty_target(self, mock_sp_call):
+ self._test_dir = self._create_test_dir()
+
+ src_contents = to_bytes("some info in a file\nyup.")
+ src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents)
+
+ mock_sp_call.side_effect = self._faux_command
+ ve = self._vault_editor()
+
+ b_ciphertext = ve._edit_file_helper(src_file_path, self.vault_secret)
+
+ self.assertNotEqual(src_contents, b_ciphertext)
+
+ def test_stdin_binary(self):
+ stdin_data = '\0'
+
+ if PY3:
+ fake_stream = StringIO(stdin_data)
+ fake_stream.buffer = BytesIO(to_bytes(stdin_data))
+ else:
+ fake_stream = BytesIO(to_bytes(stdin_data))
+
+ with patch('sys.stdin', fake_stream):
+ ve = self._vault_editor()
+ data = ve.read_data('-')
+
+ self.assertEqual(data, b'\0')
+
+ @patch('ansible.parsing.vault.subprocess.call')
+ def test_edit_file_helper_call_exception(self, mock_sp_call):
+ self._test_dir = self._create_test_dir()
+
+ src_contents = to_bytes("some info in a file\nyup.")
+ src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents)
+
+ error_txt = 'calling editor raised an exception'
+ mock_sp_call.side_effect = errors.AnsibleError(error_txt)
+
+ ve = self._vault_editor()
+
+ self.assertRaisesRegex(errors.AnsibleError,
+ error_txt,
+ ve._edit_file_helper,
+ src_file_path,
+ self.vault_secret)
+
+ @patch('ansible.parsing.vault.subprocess.call')
+ def test_edit_file_helper_symlink_target(self, mock_sp_call):
+ self._test_dir = self._create_test_dir()
+
+ src_file_contents = to_bytes("some info in a file\nyup.")
+ src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents)
+
+ src_file_link_path = os.path.join(self._test_dir, 'a_link_to_dest_file')
+
+ os.symlink(src_file_path, src_file_link_path)
+
+ mock_sp_call.side_effect = self._faux_command
+ ve = self._vault_editor()
+
+ b_ciphertext = ve._edit_file_helper(src_file_link_path, self.vault_secret)
+
+ self.assertNotEqual(src_file_contents, b_ciphertext,
+ 'b_ciphertext should be encrypted and not equal to src_contents')
+
+ def _faux_editor(self, editor_args, new_src_contents=None):
+ if editor_args[0] == 'shred':
+ return
+
+ tmp_path = editor_args[-1]
+
+ # simulate the tmp file being editted
+ tmp_file = open(tmp_path, 'wb')
+ if new_src_contents:
+ tmp_file.write(new_src_contents)
+ tmp_file.close()
+
+ def _faux_command(self, tmp_path):
+ pass
+
+ @patch('ansible.parsing.vault.subprocess.call')
+ def test_edit_file_helper_no_change(self, mock_sp_call):
+ self._test_dir = self._create_test_dir()
+
+ src_file_contents = to_bytes("some info in a file\nyup.")
+ src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents)
+
+ # editor invocation doesn't change anything
+ def faux_editor(editor_args):
+ self._faux_editor(editor_args, src_file_contents)
+
+ mock_sp_call.side_effect = faux_editor
+ ve = self._vault_editor()
+
+ ve._edit_file_helper(src_file_path, self.vault_secret, existing_data=src_file_contents)
+
+ new_target_file = open(src_file_path, 'rb')
+ new_target_file_contents = new_target_file.read()
+ self.assertEqual(src_file_contents, new_target_file_contents)
+
+ def _assert_file_is_encrypted(self, vault_editor, src_file_path, src_contents):
+ new_src_file = open(src_file_path, 'rb')
+ new_src_file_contents = new_src_file.read()
+
+ # TODO: assert that it is encrypted
+ self.assertTrue(vault.is_encrypted(new_src_file_contents))
+
+ src_file_plaintext = vault_editor.vault.decrypt(new_src_file_contents)
+
+ # the plaintext should not be encrypted
+ self.assertFalse(vault.is_encrypted(src_file_plaintext))
+
+ # and the new plaintext should match the original
+ self.assertEqual(src_file_plaintext, src_contents)
+
+ def _assert_file_is_link(self, src_file_link_path, src_file_path):
+ self.assertTrue(os.path.islink(src_file_link_path),
+ 'The dest path (%s) should be a symlink to (%s) but is not' % (src_file_link_path, src_file_path))
+
+ def test_rekey_file(self):
+ self._test_dir = self._create_test_dir()
+
+ src_file_contents = to_bytes("some info in a file\nyup.")
+ src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents)
+
+ ve = self._vault_editor()
+ ve.encrypt_file(src_file_path, self.vault_secret)
+
+ # FIXME: update to just set self._secrets or just a new vault secret id
+ new_password = 'password2:electricbugaloo'
+ new_vault_secret = TextVaultSecret(new_password)
+ new_vault_secrets = [('default', new_vault_secret)]
+ ve.rekey_file(src_file_path, vault.match_encrypt_secret(new_vault_secrets)[1])
+
+ # FIXME: can just update self._secrets here
+ new_ve = vault.VaultEditor(VaultLib(new_vault_secrets))
+ self._assert_file_is_encrypted(new_ve, src_file_path, src_file_contents)
+
+ def test_rekey_file_no_new_password(self):
+ self._test_dir = self._create_test_dir()
+
+ src_file_contents = to_bytes("some info in a file\nyup.")
+ src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents)
+
+ ve = self._vault_editor()
+ ve.encrypt_file(src_file_path, self.vault_secret)
+
+ self.assertRaisesRegex(errors.AnsibleError,
+ 'The value for the new_password to rekey',
+ ve.rekey_file,
+ src_file_path,
+ None)
+
+ def test_rekey_file_not_encrypted(self):
+ self._test_dir = self._create_test_dir()
+
+ src_file_contents = to_bytes("some info in a file\nyup.")
+ src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents)
+
+ ve = self._vault_editor()
+
+ new_password = 'password2:electricbugaloo'
+ self.assertRaisesRegex(errors.AnsibleError,
+ 'input is not vault encrypted data',
+ ve.rekey_file,
+ src_file_path, new_password)
+
+ def test_plaintext(self):
+ self._test_dir = self._create_test_dir()
+
+ src_file_contents = to_bytes("some info in a file\nyup.")
+ src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents)
+
+ ve = self._vault_editor()
+ ve.encrypt_file(src_file_path, self.vault_secret)
+
+ res = ve.plaintext(src_file_path)
+ self.assertEqual(src_file_contents, res)
+
+ def test_plaintext_not_encrypted(self):
+ self._test_dir = self._create_test_dir()
+
+ src_file_contents = to_bytes("some info in a file\nyup.")
+ src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents)
+
+ ve = self._vault_editor()
+ self.assertRaisesRegex(errors.AnsibleError,
+ 'input is not vault encrypted data',
+ ve.plaintext,
+ src_file_path)
+
+ def test_encrypt_file(self):
+ self._test_dir = self._create_test_dir()
+ src_file_contents = to_bytes("some info in a file\nyup.")
+ src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents)
+
+ ve = self._vault_editor()
+ ve.encrypt_file(src_file_path, self.vault_secret)
+
+ self._assert_file_is_encrypted(ve, src_file_path, src_file_contents)
+
+ def test_encrypt_file_symlink(self):
+ self._test_dir = self._create_test_dir()
+
+ src_file_contents = to_bytes("some info in a file\nyup.")
+ src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents)
+
+ src_file_link_path = os.path.join(self._test_dir, 'a_link_to_dest_file')
+ os.symlink(src_file_path, src_file_link_path)
+
+ ve = self._vault_editor()
+ ve.encrypt_file(src_file_link_path, self.vault_secret)
+
+ self._assert_file_is_encrypted(ve, src_file_path, src_file_contents)
+ self._assert_file_is_encrypted(ve, src_file_link_path, src_file_contents)
+
+ self._assert_file_is_link(src_file_link_path, src_file_path)
+
+ @patch('ansible.parsing.vault.subprocess.call')
+ def test_edit_file_no_vault_id(self, mock_sp_call):
+ self._test_dir = self._create_test_dir()
+ src_contents = to_bytes("some info in a file\nyup.")
+
+ src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents)
+
+ new_src_contents = to_bytes("The info is different now.")
+
+ def faux_editor(editor_args):
+ self._faux_editor(editor_args, new_src_contents)
+
+ mock_sp_call.side_effect = faux_editor
+
+ ve = self._vault_editor()
+
+ ve.encrypt_file(src_file_path, self.vault_secret)
+ ve.edit_file(src_file_path)
+
+ new_src_file = open(src_file_path, 'rb')
+ new_src_file_contents = new_src_file.read()
+
+ self.assertTrue(b'$ANSIBLE_VAULT;1.1;AES256' in new_src_file_contents)
+
+ src_file_plaintext = ve.vault.decrypt(new_src_file_contents)
+ self.assertEqual(src_file_plaintext, new_src_contents)
+
+ @patch('ansible.parsing.vault.subprocess.call')
+ def test_edit_file_with_vault_id(self, mock_sp_call):
+ self._test_dir = self._create_test_dir()
+ src_contents = to_bytes("some info in a file\nyup.")
+
+ src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents)
+
+ new_src_contents = to_bytes("The info is different now.")
+
+ def faux_editor(editor_args):
+ self._faux_editor(editor_args, new_src_contents)
+
+ mock_sp_call.side_effect = faux_editor
+
+ ve = self._vault_editor()
+
+ ve.encrypt_file(src_file_path, self.vault_secret,
+ vault_id='vault_secrets')
+ ve.edit_file(src_file_path)
+
+ new_src_file = open(src_file_path, 'rb')
+ new_src_file_contents = new_src_file.read()
+
+ self.assertTrue(b'$ANSIBLE_VAULT;1.2;AES256;vault_secrets' in new_src_file_contents)
+
+ src_file_plaintext = ve.vault.decrypt(new_src_file_contents)
+ self.assertEqual(src_file_plaintext, new_src_contents)
+
+ @patch('ansible.parsing.vault.subprocess.call')
+ def test_edit_file_symlink(self, mock_sp_call):
+ self._test_dir = self._create_test_dir()
+ src_contents = to_bytes("some info in a file\nyup.")
+
+ src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents)
+
+ new_src_contents = to_bytes("The info is different now.")
+
+ def faux_editor(editor_args):
+ self._faux_editor(editor_args, new_src_contents)
+
+ mock_sp_call.side_effect = faux_editor
+
+ ve = self._vault_editor()
+
+ ve.encrypt_file(src_file_path, self.vault_secret)
+
+ src_file_link_path = os.path.join(self._test_dir, 'a_link_to_dest_file')
+
+ os.symlink(src_file_path, src_file_link_path)
+
+ ve.edit_file(src_file_link_path)
+
+ new_src_file = open(src_file_path, 'rb')
+ new_src_file_contents = new_src_file.read()
+
+ src_file_plaintext = ve.vault.decrypt(new_src_file_contents)
+
+ self._assert_file_is_link(src_file_link_path, src_file_path)
+
+ self.assertEqual(src_file_plaintext, new_src_contents)
+
+ # self.assertEqual(src_file_plaintext, new_src_contents,
+ # 'The decrypted plaintext of the editted file is not the expected contents.')
+
+ @patch('ansible.parsing.vault.subprocess.call')
+ def test_edit_file_not_encrypted(self, mock_sp_call):
+ self._test_dir = self._create_test_dir()
+ src_contents = to_bytes("some info in a file\nyup.")
+
+ src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents)
+
+ new_src_contents = to_bytes("The info is different now.")
+
+ def faux_editor(editor_args):
+ self._faux_editor(editor_args, new_src_contents)
+
+ mock_sp_call.side_effect = faux_editor
+
+ ve = self._vault_editor()
+ self.assertRaisesRegex(errors.AnsibleError,
+ 'input is not vault encrypted data',
+ ve.edit_file,
+ src_file_path)
+
+ def test_create_file_exists(self):
+ self._test_dir = self._create_test_dir()
+ src_contents = to_bytes("some info in a file\nyup.")
+ src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents)
+
+ ve = self._vault_editor()
+ self.assertRaisesRegex(errors.AnsibleError,
+ 'please use .edit. instead',
+ ve.create_file,
+ src_file_path,
+ self.vault_secret)
+
+ def test_decrypt_file_exception(self):
+ self._test_dir = self._create_test_dir()
+ src_contents = to_bytes("some info in a file\nyup.")
+ src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents)
+
+ ve = self._vault_editor()
+ self.assertRaisesRegex(errors.AnsibleError,
+ 'input is not vault encrypted data',
+ ve.decrypt_file,
+ src_file_path)
+
+ @patch.object(vault.VaultEditor, '_editor_shell_command')
+ def test_create_file(self, mock_editor_shell_command):
+
+ def sc_side_effect(filename):
+ return ['touch', filename]
+ mock_editor_shell_command.side_effect = sc_side_effect
+
+ tmp_file = tempfile.NamedTemporaryFile()
+ os.unlink(tmp_file.name)
+
+ _secrets = self._secrets('ansible')
+ ve = self._vault_editor(_secrets)
+ ve.create_file(tmp_file.name, vault.match_encrypt_secret(_secrets)[1])
+
+ self.assertTrue(os.path.exists(tmp_file.name))
+
+ def test_decrypt_1_1(self):
+ v11_file = tempfile.NamedTemporaryFile(delete=False)
+ with v11_file as f:
+ f.write(to_bytes(v11_data))
+
+ ve = self._vault_editor(self._secrets("ansible"))
+
+ # make sure the password functions for the cipher
+ error_hit = False
+ try:
+ ve.decrypt_file(v11_file.name)
+ except errors.AnsibleError:
+ error_hit = True
+
+ # verify decrypted content
+ f = open(v11_file.name, "rb")
+ fdata = to_text(f.read())
+ f.close()
+
+ os.unlink(v11_file.name)
+
+ assert error_hit is False, "error decrypting 1.1 file"
+ assert fdata.strip() == "foo", "incorrect decryption of 1.1 file: %s" % fdata.strip()
+
+ def test_real_path_dash(self):
+ filename = '-'
+ ve = self._vault_editor()
+
+ res = ve._real_path(filename)
+ self.assertEqual(res, '-')
+
+ def test_real_path_dev_null(self):
+ filename = '/dev/null'
+ ve = self._vault_editor()
+
+ res = ve._real_path(filename)
+ self.assertEqual(res, '/dev/null')
+
+ def test_real_path_symlink(self):
+ self._test_dir = os.path.realpath(self._create_test_dir())
+ file_path = self._create_file(self._test_dir, 'test_file', content=b'this is a test file')
+ file_link_path = os.path.join(self._test_dir, 'a_link_to_test_file')
+
+ os.symlink(file_path, file_link_path)
+
+ ve = self._vault_editor()
+
+ res = ve._real_path(file_link_path)
+ self.assertEqual(res, file_path)
diff --git a/test/units/parsing/yaml/__init__.py b/test/units/parsing/yaml/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/units/parsing/yaml/__init__.py
diff --git a/test/units/parsing/yaml/test_constructor.py b/test/units/parsing/yaml/test_constructor.py
new file mode 100644
index 0000000..717bf35
--- /dev/null
+++ b/test/units/parsing/yaml/test_constructor.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# (c) 2020 Matt Martz <matt@sivel.net>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+import pytest
+from yaml import MappingNode, Mark, ScalarNode
+from yaml.constructor import ConstructorError
+
+import ansible.constants as C
+from ansible.utils.display import Display
+from ansible.parsing.yaml.constructor import AnsibleConstructor
+
+
+@pytest.fixture
+def dupe_node():
+ tag = 'tag:yaml.org,2002:map'
+ scalar_tag = 'tag:yaml.org,2002:str'
+ mark = Mark(tag, 0, 0, 0, None, None)
+ node = MappingNode(
+ tag,
+ [
+ (
+ ScalarNode(tag=scalar_tag, value='bar', start_mark=mark),
+ ScalarNode(tag=scalar_tag, value='baz', start_mark=mark)
+ ),
+ (
+ ScalarNode(tag=scalar_tag, value='bar', start_mark=mark),
+ ScalarNode(tag=scalar_tag, value='qux', start_mark=mark)
+ ),
+ ],
+ start_mark=mark
+ )
+
+ return node
+
+
+class Capture:
+ def __init__(self):
+ self.called = False
+ self.calls = []
+
+ def __call__(self, *args, **kwargs):
+ self.called = True
+ self.calls.append((
+ args,
+ kwargs
+ ))
+
+
+def test_duplicate_yaml_dict_key_ignore(dupe_node, monkeypatch):
+ monkeypatch.setattr(C, 'DUPLICATE_YAML_DICT_KEY', 'ignore')
+ cap = Capture()
+ monkeypatch.setattr(Display(), 'warning', cap)
+ ac = AnsibleConstructor()
+ ac.construct_mapping(dupe_node)
+ assert not cap.called
+
+
+def test_duplicate_yaml_dict_key_warn(dupe_node, monkeypatch):
+ monkeypatch.setattr(C, 'DUPLICATE_YAML_DICT_KEY', 'warn')
+ cap = Capture()
+ monkeypatch.setattr(Display(), 'warning', cap)
+ ac = AnsibleConstructor()
+ ac.construct_mapping(dupe_node)
+ assert cap.called
+ expected = [
+ (
+ (
+ 'While constructing a mapping from tag:yaml.org,2002:map, line 1, column 1, '
+ 'found a duplicate dict key (bar). Using last defined value only.',
+ ),
+ {}
+ )
+ ]
+ assert cap.calls == expected
+
+
+def test_duplicate_yaml_dict_key_error(dupe_node, monkeypatch, mocker):
+ monkeypatch.setattr(C, 'DUPLICATE_YAML_DICT_KEY', 'error')
+ ac = AnsibleConstructor()
+ pytest.raises(ConstructorError, ac.construct_mapping, dupe_node)
diff --git a/test/units/parsing/yaml/test_dumper.py b/test/units/parsing/yaml/test_dumper.py
new file mode 100644
index 0000000..5fbc139
--- /dev/null
+++ b/test/units/parsing/yaml/test_dumper.py
@@ -0,0 +1,123 @@
+# coding: utf-8
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import io
+import yaml
+
+from jinja2.exceptions import UndefinedError
+
+from units.compat import unittest
+from ansible.parsing import vault
+from ansible.parsing.yaml import dumper, objects
+from ansible.parsing.yaml.loader import AnsibleLoader
+from ansible.module_utils.six import PY2
+from ansible.template import AnsibleUndefined
+from ansible.utils.unsafe_proxy import AnsibleUnsafeText, AnsibleUnsafeBytes
+
+from units.mock.yaml_helper import YamlTestUtils
+from units.mock.vault_helper import TextVaultSecret
+from ansible.vars.manager import VarsWithSources
+
+
+class TestAnsibleDumper(unittest.TestCase, YamlTestUtils):
+ def setUp(self):
+ self.vault_password = "hunter42"
+ vault_secret = TextVaultSecret(self.vault_password)
+ self.vault_secrets = [('vault_secret', vault_secret)]
+ self.good_vault = vault.VaultLib(self.vault_secrets)
+ self.vault = self.good_vault
+ self.stream = self._build_stream()
+ self.dumper = dumper.AnsibleDumper
+
+ def _build_stream(self, yaml_text=None):
+ text = yaml_text or u''
+ stream = io.StringIO(text)
+ return stream
+
+ def _loader(self, stream):
+ return AnsibleLoader(stream, vault_secrets=self.vault.secrets)
+
+ def test_ansible_vault_encrypted_unicode(self):
+ plaintext = 'This is a string we are going to encrypt.'
+ avu = objects.AnsibleVaultEncryptedUnicode.from_plaintext(plaintext, vault=self.vault,
+ secret=vault.match_secrets(self.vault_secrets, ['vault_secret'])[0][1])
+
+ yaml_out = self._dump_string(avu, dumper=self.dumper)
+ stream = self._build_stream(yaml_out)
+ loader = self._loader(stream)
+
+ data_from_yaml = loader.get_single_data()
+
+ self.assertEqual(plaintext, data_from_yaml.data)
+
+ def test_bytes(self):
+ b_text = u'tréma'.encode('utf-8')
+ unsafe_object = AnsibleUnsafeBytes(b_text)
+ yaml_out = self._dump_string(unsafe_object, dumper=self.dumper)
+
+ stream = self._build_stream(yaml_out)
+ loader = self._loader(stream)
+
+ data_from_yaml = loader.get_single_data()
+
+ result = b_text
+ if PY2:
+ # https://pyyaml.org/wiki/PyYAMLDocumentation#string-conversion-python-2-only
+ # pyyaml on Python 2 can return either unicode or bytes when given byte strings.
+ # We normalize that to always return unicode on Python2 as that's right most of the
+ # time. However, this means byte strings can round trip through yaml on Python3 but
+ # not on Python2. To make this code work the same on Python2 and Python3 (we want
+ # the Python3 behaviour) we need to change the methods in Ansible to:
+ # (1) Let byte strings pass through yaml without being converted on Python2
+ # (2) Convert byte strings to text strings before being given to pyyaml (Without this,
+ # strings would end up as byte strings most of the time which would mostly be wrong)
+ # In practice, we mostly read bytes in from files and then pass that to pyyaml, for which
+ # the present behavior is correct.
+ # This is a workaround for the current behavior.
+ result = u'tr\xe9ma'
+
+ self.assertEqual(result, data_from_yaml)
+
+ def test_unicode(self):
+ u_text = u'nöel'
+ unsafe_object = AnsibleUnsafeText(u_text)
+ yaml_out = self._dump_string(unsafe_object, dumper=self.dumper)
+
+ stream = self._build_stream(yaml_out)
+ loader = self._loader(stream)
+
+ data_from_yaml = loader.get_single_data()
+
+ self.assertEqual(u_text, data_from_yaml)
+
+ def test_vars_with_sources(self):
+ try:
+ self._dump_string(VarsWithSources(), dumper=self.dumper)
+ except yaml.representer.RepresenterError:
+ self.fail("Dump VarsWithSources raised RepresenterError unexpectedly!")
+
+ def test_undefined(self):
+ undefined_object = AnsibleUndefined()
+ try:
+ yaml_out = self._dump_string(undefined_object, dumper=self.dumper)
+ except UndefinedError:
+ yaml_out = None
+
+ self.assertIsNone(yaml_out)
diff --git a/test/units/parsing/yaml/test_loader.py b/test/units/parsing/yaml/test_loader.py
new file mode 100644
index 0000000..117f80a
--- /dev/null
+++ b/test/units/parsing/yaml/test_loader.py
@@ -0,0 +1,432 @@
+# coding: utf-8
+# (c) 2015, Toshio Kuratomi <tkuratomi@ansible.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from collections.abc import Sequence, Set, Mapping
+from io import StringIO
+
+from units.compat import unittest
+
+from ansible import errors
+from ansible.module_utils.six import text_type, binary_type
+from ansible.parsing.yaml.loader import AnsibleLoader
+from ansible.parsing import vault
+from ansible.parsing.yaml.objects import AnsibleVaultEncryptedUnicode
+from ansible.parsing.yaml.dumper import AnsibleDumper
+
+from units.mock.yaml_helper import YamlTestUtils
+from units.mock.vault_helper import TextVaultSecret
+
+from yaml.parser import ParserError
+from yaml.scanner import ScannerError
+
+
+class NameStringIO(StringIO):
+ """In py2.6, StringIO doesn't let you set name because a baseclass has it
+ as readonly property"""
+ name = None
+
+ def __init__(self, *args, **kwargs):
+ super(NameStringIO, self).__init__(*args, **kwargs)
+
+
+class TestAnsibleLoaderBasic(unittest.TestCase):
+
+ def test_parse_number(self):
+ stream = StringIO(u"""
+ 1
+ """)
+ loader = AnsibleLoader(stream, 'myfile.yml')
+ data = loader.get_single_data()
+ self.assertEqual(data, 1)
+ # No line/column info saved yet
+
+ def test_parse_string(self):
+ stream = StringIO(u"""
+ Ansible
+ """)
+ loader = AnsibleLoader(stream, 'myfile.yml')
+ data = loader.get_single_data()
+ self.assertEqual(data, u'Ansible')
+ self.assertIsInstance(data, text_type)
+
+ self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17))
+
+ def test_parse_utf8_string(self):
+ stream = StringIO(u"""
+ Cafè Eñyei
+ """)
+ loader = AnsibleLoader(stream, 'myfile.yml')
+ data = loader.get_single_data()
+ self.assertEqual(data, u'Cafè Eñyei')
+ self.assertIsInstance(data, text_type)
+
+ self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17))
+
+ def test_parse_dict(self):
+ stream = StringIO(u"""
+ webster: daniel
+ oed: oxford
+ """)
+ loader = AnsibleLoader(stream, 'myfile.yml')
+ data = loader.get_single_data()
+ self.assertEqual(data, {'webster': 'daniel', 'oed': 'oxford'})
+ self.assertEqual(len(data), 2)
+ self.assertIsInstance(list(data.keys())[0], text_type)
+ self.assertIsInstance(list(data.values())[0], text_type)
+
+ # Beginning of the first key
+ self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17))
+
+ self.assertEqual(data[u'webster'].ansible_pos, ('myfile.yml', 2, 26))
+ self.assertEqual(data[u'oed'].ansible_pos, ('myfile.yml', 3, 22))
+
+ def test_parse_list(self):
+ stream = StringIO(u"""
+ - a
+ - b
+ """)
+ loader = AnsibleLoader(stream, 'myfile.yml')
+ data = loader.get_single_data()
+ self.assertEqual(data, [u'a', u'b'])
+ self.assertEqual(len(data), 2)
+ self.assertIsInstance(data[0], text_type)
+
+ self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17))
+
+ self.assertEqual(data[0].ansible_pos, ('myfile.yml', 2, 19))
+ self.assertEqual(data[1].ansible_pos, ('myfile.yml', 3, 19))
+
+ def test_parse_short_dict(self):
+ stream = StringIO(u"""{"foo": "bar"}""")
+ loader = AnsibleLoader(stream, 'myfile.yml')
+ data = loader.get_single_data()
+ self.assertEqual(data, dict(foo=u'bar'))
+
+ self.assertEqual(data.ansible_pos, ('myfile.yml', 1, 1))
+ self.assertEqual(data[u'foo'].ansible_pos, ('myfile.yml', 1, 9))
+
+ stream = StringIO(u"""foo: bar""")
+ loader = AnsibleLoader(stream, 'myfile.yml')
+ data = loader.get_single_data()
+ self.assertEqual(data, dict(foo=u'bar'))
+
+ self.assertEqual(data.ansible_pos, ('myfile.yml', 1, 1))
+ self.assertEqual(data[u'foo'].ansible_pos, ('myfile.yml', 1, 6))
+
+ def test_error_conditions(self):
+ stream = StringIO(u"""{""")
+ loader = AnsibleLoader(stream, 'myfile.yml')
+ self.assertRaises(ParserError, loader.get_single_data)
+
+ def test_tab_error(self):
+ stream = StringIO(u"""---\nhosts: localhost\nvars:\n foo: bar\n\tblip: baz""")
+ loader = AnsibleLoader(stream, 'myfile.yml')
+ self.assertRaises(ScannerError, loader.get_single_data)
+
+ def test_front_matter(self):
+ stream = StringIO(u"""---\nfoo: bar""")
+ loader = AnsibleLoader(stream, 'myfile.yml')
+ data = loader.get_single_data()
+ self.assertEqual(data, dict(foo=u'bar'))
+
+ self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 1))
+ self.assertEqual(data[u'foo'].ansible_pos, ('myfile.yml', 2, 6))
+
+ # Initial indent (See: #6348)
+ stream = StringIO(u""" - foo: bar\n baz: qux""")
+ loader = AnsibleLoader(stream, 'myfile.yml')
+ data = loader.get_single_data()
+ self.assertEqual(data, [{u'foo': u'bar', u'baz': u'qux'}])
+
+ self.assertEqual(data.ansible_pos, ('myfile.yml', 1, 2))
+ self.assertEqual(data[0].ansible_pos, ('myfile.yml', 1, 4))
+ self.assertEqual(data[0][u'foo'].ansible_pos, ('myfile.yml', 1, 9))
+ self.assertEqual(data[0][u'baz'].ansible_pos, ('myfile.yml', 2, 9))
+
+
+class TestAnsibleLoaderVault(unittest.TestCase, YamlTestUtils):
+ def setUp(self):
+ self.vault_password = "hunter42"
+ vault_secret = TextVaultSecret(self.vault_password)
+ self.vault_secrets = [('vault_secret', vault_secret),
+ ('default', vault_secret)]
+ self.vault = vault.VaultLib(self.vault_secrets)
+
+ @property
+ def vault_secret(self):
+ return vault.match_encrypt_secret(self.vault_secrets)[1]
+
+ def test_wrong_password(self):
+ plaintext = u"Ansible"
+ bob_password = "this is a different password"
+
+ bobs_secret = TextVaultSecret(bob_password)
+ bobs_secrets = [('default', bobs_secret)]
+
+ bobs_vault = vault.VaultLib(bobs_secrets)
+
+ ciphertext = bobs_vault.encrypt(plaintext, vault.match_encrypt_secret(bobs_secrets)[1])
+
+ try:
+ self.vault.decrypt(ciphertext)
+ except Exception as e:
+ self.assertIsInstance(e, errors.AnsibleError)
+ self.assertEqual(e.message, 'Decryption failed (no vault secrets were found that could decrypt)')
+
+ def _encrypt_plaintext(self, plaintext):
+ # Construct a yaml repr of a vault by hand
+ vaulted_var_bytes = self.vault.encrypt(plaintext, self.vault_secret)
+
+ # add yaml tag
+ vaulted_var = vaulted_var_bytes.decode()
+ lines = vaulted_var.splitlines()
+ lines2 = []
+ for line in lines:
+ lines2.append(' %s' % line)
+
+ vaulted_var = '\n'.join(lines2)
+ tagged_vaulted_var = u"""!vault |\n%s""" % vaulted_var
+ return tagged_vaulted_var
+
+ def _build_stream(self, yaml_text):
+ stream = NameStringIO(yaml_text)
+ stream.name = 'my.yml'
+ return stream
+
+ def _loader(self, stream):
+ return AnsibleLoader(stream, vault_secrets=self.vault.secrets)
+
+ def _load_yaml(self, yaml_text, password):
+ stream = self._build_stream(yaml_text)
+ loader = self._loader(stream)
+
+ data_from_yaml = loader.get_single_data()
+
+ return data_from_yaml
+
+ def test_dump_load_cycle(self):
+ avu = AnsibleVaultEncryptedUnicode.from_plaintext('The plaintext for test_dump_load_cycle.', self.vault, self.vault_secret)
+ self._dump_load_cycle(avu)
+
+ def test_embedded_vault_from_dump(self):
+ avu = AnsibleVaultEncryptedUnicode.from_plaintext('setec astronomy', self.vault, self.vault_secret)
+ blip = {'stuff1': [{'a dict key': 24},
+ {'shhh-ssh-secrets': avu,
+ 'nothing to see here': 'move along'}],
+ 'another key': 24.1}
+
+ blip = ['some string', 'another string', avu]
+ stream = NameStringIO()
+
+ self._dump_stream(blip, stream, dumper=AnsibleDumper)
+
+ stream.seek(0)
+
+ stream.seek(0)
+
+ loader = self._loader(stream)
+
+ data_from_yaml = loader.get_data()
+
+ stream2 = NameStringIO(u'')
+ # verify we can dump the object again
+ self._dump_stream(data_from_yaml, stream2, dumper=AnsibleDumper)
+
+ def test_embedded_vault(self):
+ plaintext_var = u"""This is the plaintext string."""
+ tagged_vaulted_var = self._encrypt_plaintext(plaintext_var)
+ another_vaulted_var = self._encrypt_plaintext(plaintext_var)
+
+ different_var = u"""A different string that is not the same as the first one."""
+ different_vaulted_var = self._encrypt_plaintext(different_var)
+
+ yaml_text = u"""---\nwebster: daniel\noed: oxford\nthe_secret: %s\nanother_secret: %s\ndifferent_secret: %s""" % (tagged_vaulted_var,
+ another_vaulted_var,
+ different_vaulted_var)
+
+ data_from_yaml = self._load_yaml(yaml_text, self.vault_password)
+ vault_string = data_from_yaml['the_secret']
+
+ self.assertEqual(plaintext_var, data_from_yaml['the_secret'])
+
+ test_dict = {}
+ test_dict[vault_string] = 'did this work?'
+
+ self.assertEqual(vault_string.data, vault_string)
+
+ # This looks weird and useless, but the object in question has a custom __eq__
+ self.assertEqual(vault_string, vault_string)
+
+ another_vault_string = data_from_yaml['another_secret']
+ different_vault_string = data_from_yaml['different_secret']
+
+ self.assertEqual(vault_string, another_vault_string)
+ self.assertNotEqual(vault_string, different_vault_string)
+
+ # More testing of __eq__/__ne__
+ self.assertTrue('some string' != vault_string)
+ self.assertNotEqual('some string', vault_string)
+
+ # Note this is a compare of the str/unicode of these, they are different types
+ # so we want to test self == other, and other == self etc
+ self.assertEqual(plaintext_var, vault_string)
+ self.assertEqual(vault_string, plaintext_var)
+ self.assertFalse(plaintext_var != vault_string)
+ self.assertFalse(vault_string != plaintext_var)
+
+
+class TestAnsibleLoaderPlay(unittest.TestCase):
+
+ def setUp(self):
+ stream = NameStringIO(u"""
+ - hosts: localhost
+ vars:
+ number: 1
+ string: Ansible
+ utf8_string: Cafè Eñyei
+ dictionary:
+ webster: daniel
+ oed: oxford
+ list:
+ - a
+ - b
+ - 1
+ - 2
+ tasks:
+ - name: Test case
+ ping:
+ data: "{{ utf8_string }}"
+
+ - name: Test 2
+ ping:
+ data: "Cafè Eñyei"
+
+ - name: Test 3
+ command: "printf 'Cafè Eñyei\\n'"
+ """)
+ self.play_filename = '/path/to/myplay.yml'
+ stream.name = self.play_filename
+ self.loader = AnsibleLoader(stream)
+ self.data = self.loader.get_single_data()
+
+ def tearDown(self):
+ pass
+
+ def test_data_complete(self):
+ self.assertEqual(len(self.data), 1)
+ self.assertIsInstance(self.data, list)
+ self.assertEqual(frozenset(self.data[0].keys()), frozenset((u'hosts', u'vars', u'tasks')))
+
+ self.assertEqual(self.data[0][u'hosts'], u'localhost')
+
+ self.assertEqual(self.data[0][u'vars'][u'number'], 1)
+ self.assertEqual(self.data[0][u'vars'][u'string'], u'Ansible')
+ self.assertEqual(self.data[0][u'vars'][u'utf8_string'], u'Cafè Eñyei')
+ self.assertEqual(self.data[0][u'vars'][u'dictionary'], {
+ u'webster': u'daniel',
+ u'oed': u'oxford'
+ })
+ self.assertEqual(self.data[0][u'vars'][u'list'], [u'a', u'b', 1, 2])
+
+ self.assertEqual(self.data[0][u'tasks'], [
+ {u'name': u'Test case', u'ping': {u'data': u'{{ utf8_string }}'}},
+ {u'name': u'Test 2', u'ping': {u'data': u'Cafè Eñyei'}},
+ {u'name': u'Test 3', u'command': u'printf \'Cafè Eñyei\n\''},
+ ])
+
+ def walk(self, data):
+ # Make sure there's no str in the data
+ self.assertNotIsInstance(data, binary_type)
+
+ # Descend into various container types
+ if isinstance(data, text_type):
+ # strings are a sequence so we have to be explicit here
+ return
+ elif isinstance(data, (Sequence, Set)):
+ for element in data:
+ self.walk(element)
+ elif isinstance(data, Mapping):
+ for k, v in data.items():
+ self.walk(k)
+ self.walk(v)
+
+ # Scalars were all checked so we're good to go
+ return
+
+ def test_no_str_in_data(self):
+ # Checks that no strings are str type
+ self.walk(self.data)
+
+ def check_vars(self):
+ # Numbers don't have line/col information yet
+ # self.assertEqual(self.data[0][u'vars'][u'number'].ansible_pos, (self.play_filename, 4, 21))
+
+ self.assertEqual(self.data[0][u'vars'][u'string'].ansible_pos, (self.play_filename, 5, 29))
+ self.assertEqual(self.data[0][u'vars'][u'utf8_string'].ansible_pos, (self.play_filename, 6, 34))
+
+ self.assertEqual(self.data[0][u'vars'][u'dictionary'].ansible_pos, (self.play_filename, 8, 23))
+ self.assertEqual(self.data[0][u'vars'][u'dictionary'][u'webster'].ansible_pos, (self.play_filename, 8, 32))
+ self.assertEqual(self.data[0][u'vars'][u'dictionary'][u'oed'].ansible_pos, (self.play_filename, 9, 28))
+
+ self.assertEqual(self.data[0][u'vars'][u'list'].ansible_pos, (self.play_filename, 11, 23))
+ self.assertEqual(self.data[0][u'vars'][u'list'][0].ansible_pos, (self.play_filename, 11, 25))
+ self.assertEqual(self.data[0][u'vars'][u'list'][1].ansible_pos, (self.play_filename, 12, 25))
+ # Numbers don't have line/col info yet
+ # self.assertEqual(self.data[0][u'vars'][u'list'][2].ansible_pos, (self.play_filename, 13, 25))
+ # self.assertEqual(self.data[0][u'vars'][u'list'][3].ansible_pos, (self.play_filename, 14, 25))
+
+ def check_tasks(self):
+ #
+ # First Task
+ #
+ self.assertEqual(self.data[0][u'tasks'][0].ansible_pos, (self.play_filename, 16, 23))
+ self.assertEqual(self.data[0][u'tasks'][0][u'name'].ansible_pos, (self.play_filename, 16, 29))
+ self.assertEqual(self.data[0][u'tasks'][0][u'ping'].ansible_pos, (self.play_filename, 18, 25))
+ self.assertEqual(self.data[0][u'tasks'][0][u'ping'][u'data'].ansible_pos, (self.play_filename, 18, 31))
+
+ #
+ # Second Task
+ #
+ self.assertEqual(self.data[0][u'tasks'][1].ansible_pos, (self.play_filename, 20, 23))
+ self.assertEqual(self.data[0][u'tasks'][1][u'name'].ansible_pos, (self.play_filename, 20, 29))
+ self.assertEqual(self.data[0][u'tasks'][1][u'ping'].ansible_pos, (self.play_filename, 22, 25))
+ self.assertEqual(self.data[0][u'tasks'][1][u'ping'][u'data'].ansible_pos, (self.play_filename, 22, 31))
+
+ #
+ # Third Task
+ #
+ self.assertEqual(self.data[0][u'tasks'][2].ansible_pos, (self.play_filename, 24, 23))
+ self.assertEqual(self.data[0][u'tasks'][2][u'name'].ansible_pos, (self.play_filename, 24, 29))
+ self.assertEqual(self.data[0][u'tasks'][2][u'command'].ansible_pos, (self.play_filename, 25, 32))
+
+ def test_line_numbers(self):
+ # Check the line/column numbers are correct
+ # Note: Remember, currently dicts begin at the start of their first entry
+ self.assertEqual(self.data[0].ansible_pos, (self.play_filename, 2, 19))
+ self.assertEqual(self.data[0][u'hosts'].ansible_pos, (self.play_filename, 2, 26))
+ self.assertEqual(self.data[0][u'vars'].ansible_pos, (self.play_filename, 4, 21))
+
+ self.check_vars()
+
+ self.assertEqual(self.data[0][u'tasks'].ansible_pos, (self.play_filename, 16, 21))
+
+ self.check_tasks()
diff --git a/test/units/parsing/yaml/test_objects.py b/test/units/parsing/yaml/test_objects.py
new file mode 100644
index 0000000..f64b708
--- /dev/null
+++ b/test/units/parsing/yaml/test_objects.py
@@ -0,0 +1,164 @@
+# This file is part of Ansible
+# -*- coding: utf-8 -*-
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+#
+# Copyright 2016, Adrian Likins <alikins@redhat.com>
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from units.compat import unittest
+
+from ansible.errors import AnsibleError
+
+from ansible.module_utils._text import to_native
+
+from ansible.parsing import vault
+from ansible.parsing.yaml.loader import AnsibleLoader
+
+# module under test
+from ansible.parsing.yaml import objects
+
+from units.mock.yaml_helper import YamlTestUtils
+from units.mock.vault_helper import TextVaultSecret
+
+
+class TestAnsibleVaultUnicodeNoVault(unittest.TestCase, YamlTestUtils):
+ def test_empty_init(self):
+ self.assertRaises(TypeError, objects.AnsibleVaultEncryptedUnicode)
+
+ def test_empty_string_init(self):
+ seq = ''.encode('utf8')
+ self.assert_values(seq)
+
+ def test_empty_byte_string_init(self):
+ seq = b''
+ self.assert_values(seq)
+
+ def _assert_values(self, avu, seq):
+ self.assertIsInstance(avu, objects.AnsibleVaultEncryptedUnicode)
+ self.assertTrue(avu.vault is None)
+ # AnsibleVaultEncryptedUnicode without a vault should never == any string
+ self.assertNotEqual(avu, seq)
+
+ def assert_values(self, seq):
+ avu = objects.AnsibleVaultEncryptedUnicode(seq)
+ self._assert_values(avu, seq)
+
+ def test_single_char(self):
+ seq = 'a'.encode('utf8')
+ self.assert_values(seq)
+
+ def test_string(self):
+ seq = 'some letters'
+ self.assert_values(seq)
+
+ def test_byte_string(self):
+ seq = 'some letters'.encode('utf8')
+ self.assert_values(seq)
+
+
+class TestAnsibleVaultEncryptedUnicode(unittest.TestCase, YamlTestUtils):
+ def setUp(self):
+ self.good_vault_password = "hunter42"
+ good_vault_secret = TextVaultSecret(self.good_vault_password)
+ self.good_vault_secrets = [('good_vault_password', good_vault_secret)]
+ self.good_vault = vault.VaultLib(self.good_vault_secrets)
+
+ # TODO: make this use two vault secret identities instead of two vaultSecrets
+ self.wrong_vault_password = 'not-hunter42'
+ wrong_vault_secret = TextVaultSecret(self.wrong_vault_password)
+ self.wrong_vault_secrets = [('wrong_vault_password', wrong_vault_secret)]
+ self.wrong_vault = vault.VaultLib(self.wrong_vault_secrets)
+
+ self.vault = self.good_vault
+ self.vault_secrets = self.good_vault_secrets
+
+ def _loader(self, stream):
+ return AnsibleLoader(stream, vault_secrets=self.vault_secrets)
+
+ def test_dump_load_cycle(self):
+ aveu = self._from_plaintext('the test string for TestAnsibleVaultEncryptedUnicode.test_dump_load_cycle')
+ self._dump_load_cycle(aveu)
+
+ def assert_values(self, avu, seq):
+ self.assertIsInstance(avu, objects.AnsibleVaultEncryptedUnicode)
+
+ self.assertEqual(avu, seq)
+ self.assertTrue(avu.vault is self.vault)
+ self.assertIsInstance(avu.vault, vault.VaultLib)
+
+ def _from_plaintext(self, seq):
+ id_secret = vault.match_encrypt_secret(self.good_vault_secrets)
+ return objects.AnsibleVaultEncryptedUnicode.from_plaintext(seq, vault=self.vault, secret=id_secret[1])
+
+ def _from_ciphertext(self, ciphertext):
+ avu = objects.AnsibleVaultEncryptedUnicode(ciphertext)
+ avu.vault = self.vault
+ return avu
+
+ def test_empty_init(self):
+ self.assertRaises(TypeError, objects.AnsibleVaultEncryptedUnicode)
+
+ def test_empty_string_init_from_plaintext(self):
+ seq = ''
+ avu = self._from_plaintext(seq)
+ self.assert_values(avu, seq)
+
+ def test_empty_unicode_init_from_plaintext(self):
+ seq = u''
+ avu = self._from_plaintext(seq)
+ self.assert_values(avu, seq)
+
+ def test_string_from_plaintext(self):
+ seq = 'some letters'
+ avu = self._from_plaintext(seq)
+ self.assert_values(avu, seq)
+
+ def test_unicode_from_plaintext(self):
+ seq = u'some letters'
+ avu = self._from_plaintext(seq)
+ self.assert_values(avu, seq)
+
+ def test_unicode_from_plaintext_encode(self):
+ seq = u'some text here'
+ avu = self._from_plaintext(seq)
+ b_avu = avu.encode('utf-8', 'strict')
+ self.assertIsInstance(avu, objects.AnsibleVaultEncryptedUnicode)
+ self.assertEqual(b_avu, seq.encode('utf-8', 'strict'))
+ self.assertTrue(avu.vault is self.vault)
+ self.assertIsInstance(avu.vault, vault.VaultLib)
+
+ # TODO/FIXME: make sure bad password fails differently than 'thats not encrypted'
+ def test_empty_string_wrong_password(self):
+ seq = ''
+ self.vault = self.wrong_vault
+ avu = self._from_plaintext(seq)
+
+ def compare(avu, seq):
+ return avu == seq
+
+ self.assertRaises(AnsibleError, compare, avu, seq)
+
+ def test_vaulted_utf8_value_37258(self):
+ seq = u"aöffü"
+ avu = self._from_plaintext(seq)
+ self.assert_values(avu, seq)
+
+ def test_str_vaulted_utf8_value_37258(self):
+ seq = u"aöffü"
+ avu = self._from_plaintext(seq)
+ assert str(avu) == to_native(seq)