diff options
Diffstat (limited to 'test/units/parsing')
-rw-r--r-- | test/units/parsing/__init__.py | 0 | ||||
-rw-r--r-- | test/units/parsing/fixtures/ajson.json | 19 | ||||
-rw-r--r-- | test/units/parsing/fixtures/vault.yml | 6 | ||||
-rw-r--r-- | test/units/parsing/test_ajson.py | 187 | ||||
-rw-r--r-- | test/units/parsing/test_dataloader.py | 239 | ||||
-rw-r--r-- | test/units/parsing/test_mod_args.py | 137 | ||||
-rw-r--r-- | test/units/parsing/test_splitter.py | 110 | ||||
-rw-r--r-- | test/units/parsing/test_unquote.py | 51 | ||||
-rw-r--r-- | test/units/parsing/utils/__init__.py | 0 | ||||
-rw-r--r-- | test/units/parsing/utils/test_addresses.py | 98 | ||||
-rw-r--r-- | test/units/parsing/utils/test_jsonify.py | 39 | ||||
-rw-r--r-- | test/units/parsing/utils/test_yaml.py | 34 | ||||
-rw-r--r-- | test/units/parsing/vault/__init__.py | 0 | ||||
-rw-r--r-- | test/units/parsing/vault/test_vault.py | 941 | ||||
-rw-r--r-- | test/units/parsing/vault/test_vault_editor.py | 517 | ||||
-rw-r--r-- | test/units/parsing/yaml/__init__.py | 0 | ||||
-rw-r--r-- | test/units/parsing/yaml/test_dumper.py | 103 | ||||
-rw-r--r-- | test/units/parsing/yaml/test_loader.py | 436 | ||||
-rw-r--r-- | test/units/parsing/yaml/test_objects.py | 164 |
19 files changed, 3081 insertions, 0 deletions
diff --git a/test/units/parsing/__init__.py b/test/units/parsing/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/units/parsing/__init__.py diff --git a/test/units/parsing/fixtures/ajson.json b/test/units/parsing/fixtures/ajson.json new file mode 100644 index 00000000..dafec0b3 --- /dev/null +++ b/test/units/parsing/fixtures/ajson.json @@ -0,0 +1,19 @@ +{ + "password": { + "__ansible_vault": "$ANSIBLE_VAULT;1.1;AES256\n34646264306632313333393636316562356435376162633631326264383934326565333633366238\n3863373264326461623132613931346165636465346337310a326434313830316337393263616439\n64653937313463396366633861363266633465663730303633323534363331316164623237363831\n3536333561393238370a313330316263373938326162386433313336613532653538376662306435\n3339\n" + }, + "bar": { + "baz": [ + { + "password": { + "__ansible_vault": "$ANSIBLE_VAULT;1.1;AES256\n34646264306632313333393636316562356435376162633631326264383934326565333633366238\n3863373264326461623132613931346165636465346337310a326434313830316337393263616439\n64653937313463396366633861363266633465663730303633323534363331316164623237363831\n3536333561393238370a313330316263373938326162386433313336613532653538376662306435\n3338\n" + } + } + ] + }, + "foo": { + "password": { + "__ansible_vault": "$ANSIBLE_VAULT;1.1;AES256\n34646264306632313333393636316562356435376162633631326264383934326565333633366238\n3863373264326461623132613931346165636465346337310a326434313830316337393263616439\n64653937313463396366633861363266633465663730303633323534363331316164623237363831\n3536333561393238370a313330316263373938326162386433313336613532653538376662306435\n3339\n" + } + } +} diff --git a/test/units/parsing/fixtures/vault.yml b/test/units/parsing/fixtures/vault.yml new file mode 100644 index 00000000..ca33ab25 --- /dev/null +++ b/test/units/parsing/fixtures/vault.yml @@ -0,0 +1,6 @@ +$ANSIBLE_VAULT;1.1;AES256 +33343734386261666161626433386662623039356366656637303939306563376130623138626165 +6436333766346533353463636566313332623130383662340a393835656134633665333861393331 +37666233346464636263636530626332623035633135363732623332313534306438393366323966 +3135306561356164310a343937653834643433343734653137383339323330626437313562306630 +3035 diff --git a/test/units/parsing/test_ajson.py b/test/units/parsing/test_ajson.py new file mode 100644 index 00000000..c38f43ea --- /dev/null +++ b/test/units/parsing/test_ajson.py @@ -0,0 +1,187 @@ +# Copyright 2018, Matt Martz <matt@sivel.net> +# Copyright 2019, Andrew Klychkov @Andersson007 <aaklychkov@mail.ru> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import os +import json + +import pytest + +from datetime import date, datetime +from pytz import timezone as tz + +from ansible.module_utils.common._collections_compat import Mapping +from ansible.parsing.ajson import AnsibleJSONEncoder, AnsibleJSONDecoder +from ansible.parsing.yaml.objects import AnsibleVaultEncryptedUnicode +from ansible.utils.unsafe_proxy import AnsibleUnsafeText + + +def test_AnsibleJSONDecoder_vault(): + with open(os.path.join(os.path.dirname(__file__), 'fixtures/ajson.json')) as f: + data = json.load(f, cls=AnsibleJSONDecoder) + + assert isinstance(data['password'], AnsibleVaultEncryptedUnicode) + assert isinstance(data['bar']['baz'][0]['password'], AnsibleVaultEncryptedUnicode) + assert isinstance(data['foo']['password'], AnsibleVaultEncryptedUnicode) + + +def test_encode_decode_unsafe(): + data = { + 'key_value': AnsibleUnsafeText(u'{#NOTACOMMENT#}'), + 'list': [AnsibleUnsafeText(u'{#NOTACOMMENT#}')], + 'list_dict': [{'key_value': AnsibleUnsafeText(u'{#NOTACOMMENT#}')}]} + json_expected = ( + '{"key_value": {"__ansible_unsafe": "{#NOTACOMMENT#}"}, ' + '"list": [{"__ansible_unsafe": "{#NOTACOMMENT#}"}], ' + '"list_dict": [{"key_value": {"__ansible_unsafe": "{#NOTACOMMENT#}"}}]}' + ) + assert json.dumps(data, cls=AnsibleJSONEncoder, preprocess_unsafe=True, sort_keys=True) == json_expected + assert json.loads(json_expected, cls=AnsibleJSONDecoder) == data + + +def vault_data(): + """ + Prepare AnsibleVaultEncryptedUnicode test data for AnsibleJSONEncoder.default(). + + Return a list of tuples (input, expected). + """ + + with open(os.path.join(os.path.dirname(__file__), 'fixtures/ajson.json')) as f: + data = json.load(f, cls=AnsibleJSONDecoder) + + data_0 = data['password'] + data_1 = data['bar']['baz'][0]['password'] + + expected_0 = (u'$ANSIBLE_VAULT;1.1;AES256\n34646264306632313333393636316' + '562356435376162633631326264383934326565333633366238\n3863' + '373264326461623132613931346165636465346337310a32643431383' + '0316337393263616439\n646539373134633963666338613632666334' + '65663730303633323534363331316164623237363831\n35363335613' + '93238370a313330316263373938326162386433313336613532653538' + '376662306435\n3339\n') + + expected_1 = (u'$ANSIBLE_VAULT;1.1;AES256\n34646264306632313333393636316' + '562356435376162633631326264383934326565333633366238\n3863' + '373264326461623132613931346165636465346337310a32643431383' + '0316337393263616439\n646539373134633963666338613632666334' + '65663730303633323534363331316164623237363831\n35363335613' + '93238370a313330316263373938326162386433313336613532653538' + '376662306435\n3338\n') + + return [ + (data_0, expected_0), + (data_1, expected_1), + ] + + +class TestAnsibleJSONEncoder: + + """ + Namespace for testing AnsibleJSONEncoder. + """ + + @pytest.fixture(scope='class') + def mapping(self, request): + """ + Returns object of Mapping mock class. + + The object is used for testing handling of Mapping objects + in AnsibleJSONEncoder.default(). + Using a plain dictionary instead is not suitable because + it is handled by default encoder of the superclass (json.JSONEncoder). + """ + + class M(Mapping): + + """Mock mapping class.""" + + def __init__(self, *args, **kwargs): + self.__dict__.update(*args, **kwargs) + + def __getitem__(self, key): + return self.__dict__[key] + + def __iter__(self): + return iter(self.__dict__) + + def __len__(self): + return len(self.__dict__) + + return M(request.param) + + @pytest.fixture + def ansible_json_encoder(self): + """Return AnsibleJSONEncoder object.""" + return AnsibleJSONEncoder() + + ############### + # Test methods: + + @pytest.mark.parametrize( + 'test_input,expected', + [ + (datetime(2019, 5, 14, 13, 39, 38, 569047), '2019-05-14T13:39:38.569047'), + (datetime(2019, 5, 14, 13, 47, 16, 923866), '2019-05-14T13:47:16.923866'), + (date(2019, 5, 14), '2019-05-14'), + (date(2020, 5, 14), '2020-05-14'), + (datetime(2019, 6, 15, 14, 45, tzinfo=tz('UTC')), '2019-06-15T14:45:00+00:00'), + (datetime(2019, 6, 15, 14, 45, tzinfo=tz('Europe/Helsinki')), '2019-06-15T14:45:00+01:40'), + ] + ) + def test_date_datetime(self, ansible_json_encoder, test_input, expected): + """ + Test for passing datetime.date or datetime.datetime objects to AnsibleJSONEncoder.default(). + """ + assert ansible_json_encoder.default(test_input) == expected + + @pytest.mark.parametrize( + 'mapping,expected', + [ + ({1: 1}, {1: 1}), + ({2: 2}, {2: 2}), + ({1: 2}, {1: 2}), + ({2: 1}, {2: 1}), + ], indirect=['mapping'], + ) + def test_mapping(self, ansible_json_encoder, mapping, expected): + """ + Test for passing Mapping object to AnsibleJSONEncoder.default(). + """ + assert ansible_json_encoder.default(mapping) == expected + + @pytest.mark.parametrize('test_input,expected', vault_data()) + def test_ansible_json_decoder_vault(self, ansible_json_encoder, test_input, expected): + """ + Test for passing AnsibleVaultEncryptedUnicode to AnsibleJSONEncoder.default(). + """ + assert ansible_json_encoder.default(test_input) == {'__ansible_vault': expected} + assert json.dumps(test_input, cls=AnsibleJSONEncoder, preprocess_unsafe=True) == '{"__ansible_vault": "%s"}' % expected.replace('\n', '\\n') + + @pytest.mark.parametrize( + 'test_input,expected', + [ + ({1: 'first'}, {1: 'first'}), + ({2: 'second'}, {2: 'second'}), + ] + ) + def test_default_encoder(self, ansible_json_encoder, test_input, expected): + """ + Test for the default encoder of AnsibleJSONEncoder.default(). + + If objects of different classes that are not tested above were passed, + AnsibleJSONEncoder.default() invokes 'default()' method of json.JSONEncoder superclass. + """ + assert ansible_json_encoder.default(test_input) == expected + + @pytest.mark.parametrize('test_input', [1, 1.1, 'string', [1, 2], set('set'), True, None]) + def test_default_encoder_unserializable(self, ansible_json_encoder, test_input): + """ + Test for the default encoder of AnsibleJSONEncoder.default(), not serializable objects. + + It must fail with TypeError 'object is not serializable'. + """ + with pytest.raises(TypeError): + ansible_json_encoder.default(test_input) diff --git a/test/units/parsing/test_dataloader.py b/test/units/parsing/test_dataloader.py new file mode 100644 index 00000000..3cc8d451 --- /dev/null +++ b/test/units/parsing/test_dataloader.py @@ -0,0 +1,239 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os + +from units.compat import unittest +from units.compat.mock import patch, mock_open +from ansible.errors import AnsibleParserError, yaml_strings, AnsibleFileNotFound +from ansible.parsing.vault import AnsibleVaultError +from ansible.module_utils._text import to_text +from ansible.module_utils.six import PY3 + +from units.mock.vault_helper import TextVaultSecret +from ansible.parsing.dataloader import DataLoader + +from units.mock.path import mock_unfrackpath_noop + + +class TestDataLoader(unittest.TestCase): + + def setUp(self): + self._loader = DataLoader() + + @patch('os.path.exists') + def test__is_role(self, p_exists): + p_exists.side_effect = lambda p: p == b'test_path/tasks/main.yml' + self.assertTrue(self._loader._is_role('test_path/tasks')) + self.assertTrue(self._loader._is_role('test_path/')) + + @patch.object(DataLoader, '_get_file_contents') + def test_parse_json_from_file(self, mock_def): + mock_def.return_value = (b"""{"a": 1, "b": 2, "c": 3}""", True) + output = self._loader.load_from_file('dummy_json.txt') + self.assertEqual(output, dict(a=1, b=2, c=3)) + + @patch.object(DataLoader, '_get_file_contents') + def test_parse_yaml_from_file(self, mock_def): + mock_def.return_value = (b""" + a: 1 + b: 2 + c: 3 + """, True) + output = self._loader.load_from_file('dummy_yaml.txt') + self.assertEqual(output, dict(a=1, b=2, c=3)) + + @patch.object(DataLoader, '_get_file_contents') + def test_parse_fail_from_file(self, mock_def): + mock_def.return_value = (b""" + TEXT: + *** + NOT VALID + """, True) + self.assertRaises(AnsibleParserError, self._loader.load_from_file, 'dummy_yaml_bad.txt') + + @patch('ansible.errors.AnsibleError._get_error_lines_from_file') + @patch.object(DataLoader, '_get_file_contents') + def test_tab_error(self, mock_def, mock_get_error_lines): + mock_def.return_value = (u"""---\nhosts: localhost\nvars:\n foo: bar\n\tblip: baz""", True) + mock_get_error_lines.return_value = ('''\tblip: baz''', '''..foo: bar''') + with self.assertRaises(AnsibleParserError) as cm: + self._loader.load_from_file('dummy_yaml_text.txt') + self.assertIn(yaml_strings.YAML_COMMON_LEADING_TAB_ERROR, str(cm.exception)) + self.assertIn('foo: bar', str(cm.exception)) + + @patch('ansible.parsing.dataloader.unfrackpath', mock_unfrackpath_noop) + @patch.object(DataLoader, '_is_role') + def test_path_dwim_relative(self, mock_is_role): + """ + simulate a nested dynamic include: + + playbook.yml: + - hosts: localhost + roles: + - { role: 'testrole' } + + testrole/tasks/main.yml: + - include: "include1.yml" + static: no + + testrole/tasks/include1.yml: + - include: include2.yml + static: no + + testrole/tasks/include2.yml: + - debug: msg="blah" + """ + mock_is_role.return_value = False + with patch('os.path.exists') as mock_os_path_exists: + mock_os_path_exists.return_value = False + self._loader.path_dwim_relative('/tmp/roles/testrole/tasks', 'tasks', 'included2.yml') + + # Fetch first args for every call + # mock_os_path_exists.assert_any_call isn't used because os.path.normpath must be used in order to compare paths + called_args = [os.path.normpath(to_text(call[0][0])) for call in mock_os_path_exists.call_args_list] + + # 'path_dwim_relative' docstrings say 'with or without explicitly named dirname subdirs': + self.assertIn('/tmp/roles/testrole/tasks/included2.yml', called_args) + self.assertIn('/tmp/roles/testrole/tasks/tasks/included2.yml', called_args) + + # relative directories below are taken in account too: + self.assertIn('tasks/included2.yml', called_args) + self.assertIn('included2.yml', called_args) + + def test_path_dwim_root(self): + self.assertEqual(self._loader.path_dwim('/'), '/') + + def test_path_dwim_home(self): + self.assertEqual(self._loader.path_dwim('~'), os.path.expanduser('~')) + + def test_path_dwim_tilde_slash(self): + self.assertEqual(self._loader.path_dwim('~/'), os.path.expanduser('~')) + + def test_get_real_file(self): + self.assertEqual(self._loader.get_real_file(__file__), __file__) + + def test_is_file(self): + self.assertTrue(self._loader.is_file(__file__)) + + def test_is_directory_positive(self): + self.assertTrue(self._loader.is_directory(os.path.dirname(__file__))) + + def test_get_file_contents_none_path(self): + self.assertRaisesRegexp(AnsibleParserError, 'Invalid filename', + self._loader._get_file_contents, None) + + def test_get_file_contents_non_existent_path(self): + self.assertRaises(AnsibleFileNotFound, self._loader._get_file_contents, '/non_existent_file') + + +class TestPathDwimRelativeDataLoader(unittest.TestCase): + + def setUp(self): + self._loader = DataLoader() + + def test_all_slash(self): + self.assertEqual(self._loader.path_dwim_relative('/', '/', '/'), '/') + + def test_path_endswith_role(self): + self.assertEqual(self._loader.path_dwim_relative(path='foo/bar/tasks/', dirname='/', source='/'), '/') + + def test_path_endswith_role_main_yml(self): + self.assertIn('main.yml', self._loader.path_dwim_relative(path='foo/bar/tasks/', dirname='/', source='main.yml')) + + def test_path_endswith_role_source_tilde(self): + self.assertEqual(self._loader.path_dwim_relative(path='foo/bar/tasks/', dirname='/', source='~/'), os.path.expanduser('~')) + + +class TestPathDwimRelativeStackDataLoader(unittest.TestCase): + + def setUp(self): + self._loader = DataLoader() + + def test_none(self): + self.assertRaisesRegexp(AnsibleFileNotFound, 'on the Ansible Controller', self._loader.path_dwim_relative_stack, None, None, None) + + def test_empty_strings(self): + self.assertEqual(self._loader.path_dwim_relative_stack('', '', ''), './') + + def test_empty_lists(self): + self.assertEqual(self._loader.path_dwim_relative_stack([], '', '~/'), os.path.expanduser('~')) + + def test_all_slash(self): + self.assertEqual(self._loader.path_dwim_relative_stack('/', '/', '/'), '/') + + def test_path_endswith_role(self): + self.assertEqual(self._loader.path_dwim_relative_stack(paths=['foo/bar/tasks/'], dirname='/', source='/'), '/') + + def test_path_endswith_role_source_tilde(self): + self.assertEqual(self._loader.path_dwim_relative_stack(paths=['foo/bar/tasks/'], dirname='/', source='~/'), os.path.expanduser('~')) + + def test_path_endswith_role_source_main_yml(self): + self.assertRaises(AnsibleFileNotFound, self._loader.path_dwim_relative_stack, ['foo/bar/tasks/'], '/', 'main.yml') + + def test_path_endswith_role_source_main_yml_source_in_dirname(self): + self.assertRaises(AnsibleFileNotFound, self._loader.path_dwim_relative_stack, 'foo/bar/tasks/', 'tasks', 'tasks/main.yml') + + +class TestDataLoaderWithVault(unittest.TestCase): + + def setUp(self): + self._loader = DataLoader() + vault_secrets = [('default', TextVaultSecret('ansible'))] + self._loader.set_vault_secrets(vault_secrets) + self.test_vault_data_path = os.path.join(os.path.dirname(__file__), 'fixtures', 'vault.yml') + + def tearDown(self): + pass + + def test_get_real_file_vault(self): + real_file_path = self._loader.get_real_file(self.test_vault_data_path) + self.assertTrue(os.path.exists(real_file_path)) + + def test_get_real_file_vault_no_vault(self): + self._loader.set_vault_secrets(None) + self.assertRaises(AnsibleParserError, self._loader.get_real_file, self.test_vault_data_path) + + def test_get_real_file_vault_wrong_password(self): + wrong_vault = [('default', TextVaultSecret('wrong_password'))] + self._loader.set_vault_secrets(wrong_vault) + self.assertRaises(AnsibleVaultError, self._loader.get_real_file, self.test_vault_data_path) + + def test_get_real_file_not_a_path(self): + self.assertRaisesRegexp(AnsibleParserError, 'Invalid filename', self._loader.get_real_file, None) + + @patch.multiple(DataLoader, path_exists=lambda s, x: True, is_file=lambda s, x: True) + def test_parse_from_vault_1_1_file(self): + vaulted_data = """$ANSIBLE_VAULT;1.1;AES256 +33343734386261666161626433386662623039356366656637303939306563376130623138626165 +6436333766346533353463636566313332623130383662340a393835656134633665333861393331 +37666233346464636263636530626332623035633135363732623332313534306438393366323966 +3135306561356164310a343937653834643433343734653137383339323330626437313562306630 +3035 +""" + if PY3: + builtins_name = 'builtins' + else: + builtins_name = '__builtin__' + + with patch(builtins_name + '.open', mock_open(read_data=vaulted_data.encode('utf-8'))): + output = self._loader.load_from_file('dummy_vault.txt') + self.assertEqual(output, dict(foo='bar')) diff --git a/test/units/parsing/test_mod_args.py b/test/units/parsing/test_mod_args.py new file mode 100644 index 00000000..50c3b331 --- /dev/null +++ b/test/units/parsing/test_mod_args.py @@ -0,0 +1,137 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# Copyright 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import pytest +import re + +from ansible.errors import AnsibleParserError +from ansible.parsing.mod_args import ModuleArgsParser +from ansible.utils.sentinel import Sentinel + + +class TestModArgsDwim: + + # TODO: add tests that construct ModuleArgsParser with a task reference + # TODO: verify the AnsibleError raised on failure knows the task + # and the task knows the line numbers + + INVALID_MULTIPLE_ACTIONS = ( + ({'action': 'shell echo hi', 'local_action': 'shell echo hi'}, "action and local_action are mutually exclusive"), + ({'action': 'shell echo hi', 'shell': 'echo hi'}, "conflicting action statements: shell, shell"), + ({'local_action': 'shell echo hi', 'shell': 'echo hi'}, "conflicting action statements: shell, shell"), + ) + + def _debug(self, mod, args, to): + print("RETURNED module = {0}".format(mod)) + print(" args = {0}".format(args)) + print(" to = {0}".format(to)) + + def test_basic_shell(self): + m = ModuleArgsParser(dict(shell='echo hi')) + mod, args, to = m.parse() + self._debug(mod, args, to) + + assert mod == 'shell' + assert args == dict( + _raw_params='echo hi', + ) + assert to is Sentinel + + def test_basic_command(self): + m = ModuleArgsParser(dict(command='echo hi')) + mod, args, to = m.parse() + self._debug(mod, args, to) + + assert mod == 'command' + assert args == dict( + _raw_params='echo hi', + ) + assert to is Sentinel + + def test_shell_with_modifiers(self): + m = ModuleArgsParser(dict(shell='/bin/foo creates=/tmp/baz removes=/tmp/bleep')) + mod, args, to = m.parse() + self._debug(mod, args, to) + + assert mod == 'shell' + assert args == dict( + creates='/tmp/baz', + removes='/tmp/bleep', + _raw_params='/bin/foo', + ) + assert to is Sentinel + + def test_normal_usage(self): + m = ModuleArgsParser(dict(copy='src=a dest=b')) + mod, args, to = m.parse() + self._debug(mod, args, to) + + assert mod, 'copy' + assert args, dict(src='a', dest='b') + assert to is Sentinel + + def test_complex_args(self): + m = ModuleArgsParser(dict(copy=dict(src='a', dest='b'))) + mod, args, to = m.parse() + self._debug(mod, args, to) + + assert mod, 'copy' + assert args, dict(src='a', dest='b') + assert to is Sentinel + + def test_action_with_complex(self): + m = ModuleArgsParser(dict(action=dict(module='copy', src='a', dest='b'))) + mod, args, to = m.parse() + self._debug(mod, args, to) + + assert mod == 'copy' + assert args == dict(src='a', dest='b') + assert to is Sentinel + + def test_action_with_complex_and_complex_args(self): + m = ModuleArgsParser(dict(action=dict(module='copy', args=dict(src='a', dest='b')))) + mod, args, to = m.parse() + self._debug(mod, args, to) + + assert mod == 'copy' + assert args == dict(src='a', dest='b') + assert to is Sentinel + + def test_local_action_string(self): + m = ModuleArgsParser(dict(local_action='copy src=a dest=b')) + mod, args, delegate_to = m.parse() + self._debug(mod, args, delegate_to) + + assert mod == 'copy' + assert args == dict(src='a', dest='b') + assert delegate_to == 'localhost' + + @pytest.mark.parametrize("args_dict, msg", INVALID_MULTIPLE_ACTIONS) + def test_multiple_actions(self, args_dict, msg): + m = ModuleArgsParser(args_dict) + with pytest.raises(AnsibleParserError) as err: + m.parse() + + assert err.value.args[0] == msg + + def test_multiple_actions(self): + args_dict = {'ping': 'data=hi', 'shell': 'echo hi'} + m = ModuleArgsParser(args_dict) + with pytest.raises(AnsibleParserError) as err: + m.parse() + + assert err.value.args[0].startswith("conflicting action statements: ") + actions = set(re.search(r'(\w+), (\w+)', err.value.args[0]).groups()) + assert actions == set(['ping', 'shell']) + + def test_bogus_action(self): + args_dict = {'bogusaction': {}} + m = ModuleArgsParser(args_dict) + with pytest.raises(AnsibleParserError) as err: + m.parse() + + assert err.value.args[0].startswith("couldn't resolve module/action 'bogusaction'") diff --git a/test/units/parsing/test_splitter.py b/test/units/parsing/test_splitter.py new file mode 100644 index 00000000..a37de0f9 --- /dev/null +++ b/test/units/parsing/test_splitter.py @@ -0,0 +1,110 @@ +# coding: utf-8 +# (c) 2015, Toshio Kuratomi <tkuratomi@ansible.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.parsing.splitter import split_args, parse_kv + +import pytest + +SPLIT_DATA = ( + (u'a', + [u'a'], + {u'_raw_params': u'a'}), + (u'a=b', + [u'a=b'], + {u'a': u'b'}), + (u'a="foo bar"', + [u'a="foo bar"'], + {u'a': u'foo bar'}), + (u'"foo bar baz"', + [u'"foo bar baz"'], + {u'_raw_params': '"foo bar baz"'}), + (u'foo bar baz', + [u'foo', u'bar', u'baz'], + {u'_raw_params': u'foo bar baz'}), + (u'a=b c="foo bar"', + [u'a=b', u'c="foo bar"'], + {u'a': u'b', u'c': u'foo bar'}), + (u'a="echo \\"hello world\\"" b=bar', + [u'a="echo \\"hello world\\""', u'b=bar'], + {u'a': u'echo "hello world"', u'b': u'bar'}), + (u'a="multi\nline"', + [u'a="multi\nline"'], + {u'a': u'multi\nline'}), + (u'a="blank\n\nline"', + [u'a="blank\n\nline"'], + {u'a': u'blank\n\nline'}), + (u'a="blank\n\n\nlines"', + [u'a="blank\n\n\nlines"'], + {u'a': u'blank\n\n\nlines'}), + (u'a="a long\nmessage\\\nabout a thing\n"', + [u'a="a long\nmessage\\\nabout a thing\n"'], + {u'a': u'a long\nmessage\\\nabout a thing\n'}), + (u'a="multiline\nmessage1\\\n" b="multiline\nmessage2\\\n"', + [u'a="multiline\nmessage1\\\n"', u'b="multiline\nmessage2\\\n"'], + {u'a': 'multiline\nmessage1\\\n', u'b': u'multiline\nmessage2\\\n'}), + (u'a={{jinja}}', + [u'a={{jinja}}'], + {u'a': u'{{jinja}}'}), + (u'a={{ jinja }}', + [u'a={{ jinja }}'], + {u'a': u'{{ jinja }}'}), + (u'a="{{jinja}}"', + [u'a="{{jinja}}"'], + {u'a': u'{{jinja}}'}), + (u'a={{ jinja }}{{jinja2}}', + [u'a={{ jinja }}{{jinja2}}'], + {u'a': u'{{ jinja }}{{jinja2}}'}), + (u'a="{{ jinja }}{{jinja2}}"', + [u'a="{{ jinja }}{{jinja2}}"'], + {u'a': u'{{ jinja }}{{jinja2}}'}), + (u'a={{jinja}} b={{jinja2}}', + [u'a={{jinja}}', u'b={{jinja2}}'], + {u'a': u'{{jinja}}', u'b': u'{{jinja2}}'}), + (u'a="{{jinja}}\n" b="{{jinja2}}\n"', + [u'a="{{jinja}}\n"', u'b="{{jinja2}}\n"'], + {u'a': u'{{jinja}}\n', u'b': u'{{jinja2}}\n'}), + (u'a="café eñyei"', + [u'a="café eñyei"'], + {u'a': u'café eñyei'}), + (u'a=café b=eñyei', + [u'a=café', u'b=eñyei'], + {u'a': u'café', u'b': u'eñyei'}), + (u'a={{ foo | some_filter(\' \', " ") }} b=bar', + [u'a={{ foo | some_filter(\' \', " ") }}', u'b=bar'], + {u'a': u'{{ foo | some_filter(\' \', " ") }}', u'b': u'bar'}), + (u'One\n Two\n Three\n', + [u'One\n ', u'Two\n ', u'Three\n'], + {u'_raw_params': u'One\n Two\n Three\n'}), +) + +SPLIT_ARGS = ((test[0], test[1]) for test in SPLIT_DATA) +PARSE_KV = ((test[0], test[2]) for test in SPLIT_DATA) + + +@pytest.mark.parametrize("args, expected", SPLIT_ARGS) +def test_split_args(args, expected): + assert split_args(args) == expected + + +@pytest.mark.parametrize("args, expected", PARSE_KV) +def test_parse_kv(args, expected): + assert parse_kv(args) == expected diff --git a/test/units/parsing/test_unquote.py b/test/units/parsing/test_unquote.py new file mode 100644 index 00000000..4b4260e7 --- /dev/null +++ b/test/units/parsing/test_unquote.py @@ -0,0 +1,51 @@ +# coding: utf-8 +# (c) 2015, Toshio Kuratomi <tkuratomi@ansible.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.parsing.quoting import unquote + +import pytest + +UNQUOTE_DATA = ( + (u'1', u'1'), + (u'\'1\'', u'1'), + (u'"1"', u'1'), + (u'"1 \'2\'"', u'1 \'2\''), + (u'\'1 "2"\'', u'1 "2"'), + (u'\'1 \'2\'\'', u'1 \'2\''), + (u'"1\\"', u'"1\\"'), + (u'\'1\\\'', u'\'1\\\''), + (u'"1 \\"2\\" 3"', u'1 \\"2\\" 3'), + (u'\'1 \\\'2\\\' 3\'', u'1 \\\'2\\\' 3'), + (u'"', u'"'), + (u'\'', u'\''), + # Not entirely sure these are good but they match the current + # behaviour + (u'"1""2"', u'1""2'), + (u'\'1\'\'2\'', u'1\'\'2'), + (u'"1" 2 "3"', u'1" 2 "3'), + (u'"1"\'2\'"3"', u'1"\'2\'"3'), +) + + +@pytest.mark.parametrize("quoted, expected", UNQUOTE_DATA) +def test_unquote(quoted, expected): + assert unquote(quoted) == expected diff --git a/test/units/parsing/utils/__init__.py b/test/units/parsing/utils/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/units/parsing/utils/__init__.py diff --git a/test/units/parsing/utils/test_addresses.py b/test/units/parsing/utils/test_addresses.py new file mode 100644 index 00000000..4f7304f5 --- /dev/null +++ b/test/units/parsing/utils/test_addresses.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import unittest + +from ansible.parsing.utils.addresses import parse_address + + +class TestParseAddress(unittest.TestCase): + + tests = { + # IPv4 addresses + '192.0.2.3': ['192.0.2.3', None], + '192.0.2.3:23': ['192.0.2.3', 23], + + # IPv6 addresses + '::': ['::', None], + '::1': ['::1', None], + '[::1]:442': ['::1', 442], + 'abcd:ef98:7654:3210:abcd:ef98:7654:3210': ['abcd:ef98:7654:3210:abcd:ef98:7654:3210', None], + '[abcd:ef98:7654:3210:abcd:ef98:7654:3210]:42': ['abcd:ef98:7654:3210:abcd:ef98:7654:3210', 42], + '1234:5678:9abc:def0:1234:5678:9abc:def0': ['1234:5678:9abc:def0:1234:5678:9abc:def0', None], + '1234::9abc:def0:1234:5678:9abc:def0': ['1234::9abc:def0:1234:5678:9abc:def0', None], + '1234:5678::def0:1234:5678:9abc:def0': ['1234:5678::def0:1234:5678:9abc:def0', None], + '1234:5678:9abc::1234:5678:9abc:def0': ['1234:5678:9abc::1234:5678:9abc:def0', None], + '1234:5678:9abc:def0::5678:9abc:def0': ['1234:5678:9abc:def0::5678:9abc:def0', None], + '1234:5678:9abc:def0:1234::9abc:def0': ['1234:5678:9abc:def0:1234::9abc:def0', None], + '1234:5678:9abc:def0:1234:5678::def0': ['1234:5678:9abc:def0:1234:5678::def0', None], + '1234:5678:9abc:def0:1234:5678::': ['1234:5678:9abc:def0:1234:5678::', None], + '::9abc:def0:1234:5678:9abc:def0': ['::9abc:def0:1234:5678:9abc:def0', None], + '0:0:0:0:0:ffff:1.2.3.4': ['0:0:0:0:0:ffff:1.2.3.4', None], + '0:0:0:0:0:0:1.2.3.4': ['0:0:0:0:0:0:1.2.3.4', None], + '::ffff:1.2.3.4': ['::ffff:1.2.3.4', None], + '::1.2.3.4': ['::1.2.3.4', None], + '1234::': ['1234::', None], + + # Hostnames + 'some-host': ['some-host', None], + 'some-host:80': ['some-host', 80], + 'some.host.com:492': ['some.host.com', 492], + '[some.host.com]:493': ['some.host.com', 493], + 'a-b.3foo_bar.com:23': ['a-b.3foo_bar.com', 23], + u'fóöbär': [u'fóöbär', None], + u'fóöbär:32': [u'fóöbär', 32], + u'fóöbär.éxàmplê.com:632': [u'fóöbär.éxàmplê.com', 632], + + # Various errors + '': [None, None], + 'some..host': [None, None], + 'some.': [None, None], + '[example.com]': [None, None], + 'some-': [None, None], + 'some-.foo.com': [None, None], + 'some.-foo.com': [None, None], + } + + range_tests = { + '192.0.2.[3:10]': ['192.0.2.[3:10]', None], + '192.0.2.[3:10]:23': ['192.0.2.[3:10]', 23], + 'abcd:ef98::7654:[1:9]': ['abcd:ef98::7654:[1:9]', None], + '[abcd:ef98::7654:[6:32]]:2222': ['abcd:ef98::7654:[6:32]', 2222], + '[abcd:ef98::7654:[9ab3:fcb7]]:2222': ['abcd:ef98::7654:[9ab3:fcb7]', 2222], + u'fóöb[a:c]r.éxàmplê.com:632': [u'fóöb[a:c]r.éxàmplê.com', 632], + '[a:b]foo.com': ['[a:b]foo.com', None], + 'foo[a:b].com': ['foo[a:b].com', None], + 'foo[a:b]:42': ['foo[a:b]', 42], + 'foo[a-b]-.com': [None, None], + 'foo[a-b]:32': [None, None], + 'foo[x-y]': [None, None], + } + + def test_without_ranges(self): + for t in self.tests: + test = self.tests[t] + + try: + (host, port) = parse_address(t) + except Exception: + host = None + port = None + + assert host == test[0] + assert port == test[1] + + def test_with_ranges(self): + for t in self.range_tests: + test = self.range_tests[t] + + try: + (host, port) = parse_address(t, allow_ranges=True) + except Exception: + host = None + port = None + + assert host == test[0] + assert port == test[1] diff --git a/test/units/parsing/utils/test_jsonify.py b/test/units/parsing/utils/test_jsonify.py new file mode 100644 index 00000000..37be7824 --- /dev/null +++ b/test/units/parsing/utils/test_jsonify.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# (c) 2016, James Cammarata <jimi@sngx.net> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from units.compat import unittest +from ansible.parsing.utils.jsonify import jsonify + + +class TestJsonify(unittest.TestCase): + def test_jsonify_simple(self): + self.assertEqual(jsonify(dict(a=1, b=2, c=3)), '{"a": 1, "b": 2, "c": 3}') + + def test_jsonify_simple_format(self): + res = jsonify(dict(a=1, b=2, c=3), format=True) + cleaned = "".join([x.strip() for x in res.splitlines()]) + self.assertEqual(cleaned, '{"a": 1,"b": 2,"c": 3}') + + def test_jsonify_unicode(self): + self.assertEqual(jsonify(dict(toshio=u'くらとみ')), u'{"toshio": "くらとみ"}') + + def test_jsonify_empty(self): + self.assertEqual(jsonify(None), '{}') diff --git a/test/units/parsing/utils/test_yaml.py b/test/units/parsing/utils/test_yaml.py new file mode 100644 index 00000000..27b2905a --- /dev/null +++ b/test/units/parsing/utils/test_yaml.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# (c) 2017, Ansible Project +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import pytest + +from ansible.errors import AnsibleParserError +from ansible.parsing.utils.yaml import from_yaml + + +def test_from_yaml_simple(): + assert from_yaml(u'---\n- test: 1\n test2: "2"\n- caf\xe9: "caf\xe9"') == [{u'test': 1, u'test2': u"2"}, {u"caf\xe9": u"caf\xe9"}] + + +def test_bad_yaml(): + with pytest.raises(AnsibleParserError): + from_yaml(u'foo: bar: baz') diff --git a/test/units/parsing/vault/__init__.py b/test/units/parsing/vault/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/units/parsing/vault/__init__.py diff --git a/test/units/parsing/vault/test_vault.py b/test/units/parsing/vault/test_vault.py new file mode 100644 index 00000000..a9c4fc9e --- /dev/null +++ b/test/units/parsing/vault/test_vault.py @@ -0,0 +1,941 @@ +# -*- coding: utf-8 -*- +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# (c) 2016, Toshio Kuratomi <tkuratomi@ansible.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import binascii +import io +import os +import tempfile + +from binascii import hexlify +import pytest + +from units.compat import unittest +from units.compat.mock import patch, MagicMock + +from ansible import errors +from ansible.module_utils import six +from ansible.module_utils._text import to_bytes, to_text +from ansible.parsing import vault + +from units.mock.loader import DictDataLoader +from units.mock.vault_helper import TextVaultSecret + + +class TestUnhexlify(unittest.TestCase): + def test(self): + b_plain_data = b'some text to hexlify' + b_data = hexlify(b_plain_data) + res = vault._unhexlify(b_data) + self.assertEqual(res, b_plain_data) + + def test_odd_length(self): + b_data = b'123456789abcdefghijklmnopqrstuvwxyz' + + self.assertRaisesRegexp(vault.AnsibleVaultFormatError, + '.*Vault format unhexlify error.*', + vault._unhexlify, + b_data) + + def test_nonhex(self): + b_data = b'6z36316566653264333665333637623064303639353237620a636366633565663263336335656532' + + self.assertRaisesRegexp(vault.AnsibleVaultFormatError, + '.*Vault format unhexlify error.*Non-hexadecimal digit found', + vault._unhexlify, + b_data) + + +class TestParseVaulttext(unittest.TestCase): + def test(self): + vaulttext_envelope = u'''$ANSIBLE_VAULT;1.1;AES256 +33363965326261303234626463623963633531343539616138316433353830356566396130353436 +3562643163366231316662386565383735653432386435610a306664636137376132643732393835 +63383038383730306639353234326630666539346233376330303938323639306661313032396437 +6233623062366136310a633866373936313238333730653739323461656662303864663666653563 +3138''' + + b_vaulttext_envelope = to_bytes(vaulttext_envelope, errors='strict', encoding='utf-8') + b_vaulttext, b_version, cipher_name, vault_id = vault.parse_vaulttext_envelope(b_vaulttext_envelope) + res = vault.parse_vaulttext(b_vaulttext) + self.assertIsInstance(res[0], bytes) + self.assertIsInstance(res[1], bytes) + self.assertIsInstance(res[2], bytes) + + def test_non_hex(self): + vaulttext_envelope = u'''$ANSIBLE_VAULT;1.1;AES256 +3336396J326261303234626463623963633531343539616138316433353830356566396130353436 +3562643163366231316662386565383735653432386435610a306664636137376132643732393835 +63383038383730306639353234326630666539346233376330303938323639306661313032396437 +6233623062366136310a633866373936313238333730653739323461656662303864663666653563 +3138''' + + b_vaulttext_envelope = to_bytes(vaulttext_envelope, errors='strict', encoding='utf-8') + b_vaulttext, b_version, cipher_name, vault_id = vault.parse_vaulttext_envelope(b_vaulttext_envelope) + self.assertRaisesRegexp(vault.AnsibleVaultFormatError, + '.*Vault format unhexlify error.*Non-hexadecimal digit found', + vault.parse_vaulttext, + b_vaulttext_envelope) + + +class TestVaultSecret(unittest.TestCase): + def test(self): + secret = vault.VaultSecret() + secret.load() + self.assertIsNone(secret._bytes) + + def test_bytes(self): + some_text = u'私はガラスを食べられます。それは私を傷つけません。' + _bytes = to_bytes(some_text) + secret = vault.VaultSecret(_bytes) + secret.load() + self.assertEqual(secret.bytes, _bytes) + + +class TestPromptVaultSecret(unittest.TestCase): + def test_empty_prompt_formats(self): + secret = vault.PromptVaultSecret(vault_id='test_id', prompt_formats=[]) + secret.load() + self.assertIsNone(secret._bytes) + + @patch('ansible.parsing.vault.display.prompt', return_value='the_password') + def test_prompt_formats_none(self, mock_display_prompt): + secret = vault.PromptVaultSecret(vault_id='test_id') + secret.load() + self.assertEqual(secret._bytes, b'the_password') + + @patch('ansible.parsing.vault.display.prompt', return_value='the_password') + def test_custom_prompt(self, mock_display_prompt): + secret = vault.PromptVaultSecret(vault_id='test_id', + prompt_formats=['The cow flies at midnight: ']) + secret.load() + self.assertEqual(secret._bytes, b'the_password') + + @patch('ansible.parsing.vault.display.prompt', side_effect=EOFError) + def test_prompt_eoferror(self, mock_display_prompt): + secret = vault.PromptVaultSecret(vault_id='test_id') + self.assertRaisesRegexp(vault.AnsibleVaultError, + 'EOFError.*test_id', + secret.load) + + @patch('ansible.parsing.vault.display.prompt', side_effect=['first_password', 'second_password']) + def test_prompt_passwords_dont_match(self, mock_display_prompt): + secret = vault.PromptVaultSecret(vault_id='test_id', + prompt_formats=['Vault password: ', + 'Confirm Vault password: ']) + self.assertRaisesRegexp(errors.AnsibleError, + 'Passwords do not match', + secret.load) + + +class TestFileVaultSecret(unittest.TestCase): + def setUp(self): + self.vault_password = "test-vault-password" + text_secret = TextVaultSecret(self.vault_password) + self.vault_secrets = [('foo', text_secret)] + + def test(self): + secret = vault.FileVaultSecret() + self.assertIsNone(secret._bytes) + self.assertIsNone(secret._text) + + def test_repr_empty(self): + secret = vault.FileVaultSecret() + self.assertEqual(repr(secret), "FileVaultSecret()") + + def test_repr(self): + tmp_file = tempfile.NamedTemporaryFile(delete=False) + fake_loader = DictDataLoader({tmp_file.name: 'sdfadf'}) + + secret = vault.FileVaultSecret(loader=fake_loader, filename=tmp_file.name) + filename = tmp_file.name + tmp_file.close() + self.assertEqual(repr(secret), "FileVaultSecret(filename='%s')" % filename) + + def test_empty_bytes(self): + secret = vault.FileVaultSecret() + self.assertIsNone(secret.bytes) + + def test_file(self): + password = 'some password' + + tmp_file = tempfile.NamedTemporaryFile(delete=False) + tmp_file.write(to_bytes(password)) + tmp_file.close() + + fake_loader = DictDataLoader({tmp_file.name: 'sdfadf'}) + + secret = vault.FileVaultSecret(loader=fake_loader, filename=tmp_file.name) + secret.load() + + os.unlink(tmp_file.name) + + self.assertEqual(secret.bytes, to_bytes(password)) + + def test_file_empty(self): + + tmp_file = tempfile.NamedTemporaryFile(delete=False) + tmp_file.write(to_bytes('')) + tmp_file.close() + + fake_loader = DictDataLoader({tmp_file.name: ''}) + + secret = vault.FileVaultSecret(loader=fake_loader, filename=tmp_file.name) + self.assertRaisesRegexp(vault.AnsibleVaultPasswordError, + 'Invalid vault password was provided from file.*%s' % tmp_file.name, + secret.load) + + os.unlink(tmp_file.name) + + def test_file_encrypted(self): + vault_password = "test-vault-password" + text_secret = TextVaultSecret(vault_password) + vault_secrets = [('foo', text_secret)] + + password = 'some password' + # 'some password' encrypted with 'test-ansible-password' + + password_file_content = '''$ANSIBLE_VAULT;1.1;AES256 +61393863643638653437313566313632306462383837303132346434616433313438353634613762 +3334363431623364386164616163326537366333353663650a663634306232363432626162353665 +39623061353266373631636331643761306665343731376633623439313138396330346237653930 +6432643864346136640a653364386634666461306231353765636662316335613235383565306437 +3737 +''' + + tmp_file = tempfile.NamedTemporaryFile(delete=False) + tmp_file.write(to_bytes(password_file_content)) + tmp_file.close() + + fake_loader = DictDataLoader({tmp_file.name: 'sdfadf'}) + fake_loader._vault.secrets = vault_secrets + + secret = vault.FileVaultSecret(loader=fake_loader, filename=tmp_file.name) + secret.load() + + os.unlink(tmp_file.name) + + self.assertEqual(secret.bytes, to_bytes(password)) + + def test_file_not_a_directory(self): + filename = '/dev/null/foobar' + fake_loader = DictDataLoader({filename: 'sdfadf'}) + + secret = vault.FileVaultSecret(loader=fake_loader, filename=filename) + self.assertRaisesRegexp(errors.AnsibleError, + '.*Could not read vault password file.*/dev/null/foobar.*Not a directory', + secret.load) + + def test_file_not_found(self): + tmp_file = tempfile.NamedTemporaryFile() + filename = os.path.realpath(tmp_file.name) + tmp_file.close() + + fake_loader = DictDataLoader({filename: 'sdfadf'}) + + secret = vault.FileVaultSecret(loader=fake_loader, filename=filename) + self.assertRaisesRegexp(errors.AnsibleError, + '.*Could not read vault password file.*%s.*' % filename, + secret.load) + + +class TestScriptVaultSecret(unittest.TestCase): + def test(self): + secret = vault.ScriptVaultSecret() + self.assertIsNone(secret._bytes) + self.assertIsNone(secret._text) + + def _mock_popen(self, mock_popen, return_code=0, stdout=b'', stderr=b''): + def communicate(): + return stdout, stderr + mock_popen.return_value = MagicMock(returncode=return_code) + mock_popen_instance = mock_popen.return_value + mock_popen_instance.communicate = communicate + + @patch('ansible.parsing.vault.subprocess.Popen') + def test_read_file(self, mock_popen): + self._mock_popen(mock_popen, stdout=b'some_password') + secret = vault.ScriptVaultSecret() + with patch.object(secret, 'loader') as mock_loader: + mock_loader.is_executable = MagicMock(return_value=True) + secret.load() + + @patch('ansible.parsing.vault.subprocess.Popen') + def test_read_file_empty(self, mock_popen): + self._mock_popen(mock_popen, stdout=b'') + secret = vault.ScriptVaultSecret() + with patch.object(secret, 'loader') as mock_loader: + mock_loader.is_executable = MagicMock(return_value=True) + self.assertRaisesRegexp(vault.AnsibleVaultPasswordError, + 'Invalid vault password was provided from script', + secret.load) + + @patch('ansible.parsing.vault.subprocess.Popen') + def test_read_file_os_error(self, mock_popen): + self._mock_popen(mock_popen) + mock_popen.side_effect = OSError('That is not an executable') + secret = vault.ScriptVaultSecret() + with patch.object(secret, 'loader') as mock_loader: + mock_loader.is_executable = MagicMock(return_value=True) + self.assertRaisesRegexp(errors.AnsibleError, + 'Problem running vault password script.*', + secret.load) + + @patch('ansible.parsing.vault.subprocess.Popen') + def test_read_file_not_executable(self, mock_popen): + self._mock_popen(mock_popen) + secret = vault.ScriptVaultSecret() + with patch.object(secret, 'loader') as mock_loader: + mock_loader.is_executable = MagicMock(return_value=False) + self.assertRaisesRegexp(vault.AnsibleVaultError, + 'The vault password script .* was not executable', + secret.load) + + @patch('ansible.parsing.vault.subprocess.Popen') + def test_read_file_non_zero_return_code(self, mock_popen): + stderr = b'That did not work for a random reason' + rc = 37 + + self._mock_popen(mock_popen, return_code=rc, stderr=stderr) + secret = vault.ScriptVaultSecret(filename='/dev/null/some_vault_secret') + with patch.object(secret, 'loader') as mock_loader: + mock_loader.is_executable = MagicMock(return_value=True) + self.assertRaisesRegexp(errors.AnsibleError, + r'Vault password script.*returned non-zero \(%s\): %s' % (rc, stderr), + secret.load) + + +class TestScriptIsClient(unittest.TestCase): + def test_randomname(self): + filename = 'randomname' + res = vault.script_is_client(filename) + self.assertFalse(res) + + def test_something_dash_client(self): + filename = 'something-client' + res = vault.script_is_client(filename) + self.assertTrue(res) + + def test_something_dash_client_somethingelse(self): + filename = 'something-client-somethingelse' + res = vault.script_is_client(filename) + self.assertFalse(res) + + def test_something_dash_client_py(self): + filename = 'something-client.py' + res = vault.script_is_client(filename) + self.assertTrue(res) + + def test_full_path_something_dash_client_py(self): + filename = '/foo/bar/something-client.py' + res = vault.script_is_client(filename) + self.assertTrue(res) + + def test_full_path_something_dash_client(self): + filename = '/foo/bar/something-client' + res = vault.script_is_client(filename) + self.assertTrue(res) + + def test_full_path_something_dash_client_in_dir(self): + filename = '/foo/bar/something-client/but/not/filename' + res = vault.script_is_client(filename) + self.assertFalse(res) + + +class TestGetFileVaultSecret(unittest.TestCase): + def test_file(self): + password = 'some password' + + tmp_file = tempfile.NamedTemporaryFile(delete=False) + tmp_file.write(to_bytes(password)) + tmp_file.close() + + fake_loader = DictDataLoader({tmp_file.name: 'sdfadf'}) + + secret = vault.get_file_vault_secret(filename=tmp_file.name, loader=fake_loader) + secret.load() + + os.unlink(tmp_file.name) + + self.assertEqual(secret.bytes, to_bytes(password)) + + def test_file_not_a_directory(self): + filename = '/dev/null/foobar' + fake_loader = DictDataLoader({filename: 'sdfadf'}) + + self.assertRaisesRegexp(errors.AnsibleError, + '.*The vault password file %s was not found.*' % filename, + vault.get_file_vault_secret, + filename=filename, + loader=fake_loader) + + def test_file_not_found(self): + tmp_file = tempfile.NamedTemporaryFile() + filename = os.path.realpath(tmp_file.name) + tmp_file.close() + + fake_loader = DictDataLoader({filename: 'sdfadf'}) + + self.assertRaisesRegexp(errors.AnsibleError, + '.*The vault password file %s was not found.*' % filename, + vault.get_file_vault_secret, + filename=filename, + loader=fake_loader) + + +class TestVaultIsEncrypted(unittest.TestCase): + def test_bytes_not_encrypted(self): + b_data = b"foobar" + self.assertFalse(vault.is_encrypted(b_data)) + + def test_bytes_encrypted(self): + b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible") + self.assertTrue(vault.is_encrypted(b_data)) + + def test_text_not_encrypted(self): + b_data = to_text(b"foobar") + self.assertFalse(vault.is_encrypted(b_data)) + + def test_text_encrypted(self): + b_data = to_text(b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible")) + self.assertTrue(vault.is_encrypted(b_data)) + + def test_invalid_text_not_ascii(self): + data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % u"ァ ア ィ イ ゥ ウ ェ エ ォ オ カ ガ キ ギ ク グ ケ " + self.assertFalse(vault.is_encrypted(data)) + + def test_invalid_bytes_not_ascii(self): + data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % u"ァ ア ィ イ ゥ ウ ェ エ ォ オ カ ガ キ ギ ク グ ケ " + b_data = to_bytes(data, encoding='utf-8') + self.assertFalse(vault.is_encrypted(b_data)) + + +class TestVaultIsEncryptedFile(unittest.TestCase): + def test_binary_file_handle_not_encrypted(self): + b_data = b"foobar" + b_data_fo = io.BytesIO(b_data) + self.assertFalse(vault.is_encrypted_file(b_data_fo)) + + def test_text_file_handle_not_encrypted(self): + data = u"foobar" + data_fo = io.StringIO(data) + self.assertFalse(vault.is_encrypted_file(data_fo)) + + def test_binary_file_handle_encrypted(self): + b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible") + b_data_fo = io.BytesIO(b_data) + self.assertTrue(vault.is_encrypted_file(b_data_fo)) + + def test_text_file_handle_encrypted(self): + data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % to_text(hexlify(b"ansible")) + data_fo = io.StringIO(data) + self.assertTrue(vault.is_encrypted_file(data_fo)) + + def test_binary_file_handle_invalid(self): + data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % u"ァ ア ィ イ ゥ ウ ェ エ ォ オ カ ガ キ ギ ク グ ケ " + b_data = to_bytes(data) + b_data_fo = io.BytesIO(b_data) + self.assertFalse(vault.is_encrypted_file(b_data_fo)) + + def test_text_file_handle_invalid(self): + data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % u"ァ ア ィ イ ゥ ウ ェ エ ォ オ カ ガ キ ギ ク グ ケ " + data_fo = io.StringIO(data) + self.assertFalse(vault.is_encrypted_file(data_fo)) + + def test_file_already_read_from_finds_header(self): + b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible\ntesting\nfile pos") + b_data_fo = io.BytesIO(b_data) + b_data_fo.read(42) # Arbitrary number + self.assertTrue(vault.is_encrypted_file(b_data_fo)) + + def test_file_already_read_from_saves_file_pos(self): + b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible\ntesting\nfile pos") + b_data_fo = io.BytesIO(b_data) + b_data_fo.read(69) # Arbitrary number + vault.is_encrypted_file(b_data_fo) + self.assertEqual(b_data_fo.tell(), 69) + + def test_file_with_offset(self): + b_data = b"JUNK$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible\ntesting\nfile pos") + b_data_fo = io.BytesIO(b_data) + self.assertTrue(vault.is_encrypted_file(b_data_fo, start_pos=4)) + + def test_file_with_count(self): + b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible\ntesting\nfile pos") + vault_length = len(b_data) + b_data = b_data + u'ァ ア'.encode('utf-8') + b_data_fo = io.BytesIO(b_data) + self.assertTrue(vault.is_encrypted_file(b_data_fo, count=vault_length)) + + def test_file_with_offset_and_count(self): + b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible\ntesting\nfile pos") + vault_length = len(b_data) + b_data = b'JUNK' + b_data + u'ァ ア'.encode('utf-8') + b_data_fo = io.BytesIO(b_data) + self.assertTrue(vault.is_encrypted_file(b_data_fo, start_pos=4, count=vault_length)) + + +@pytest.mark.skipif(not vault.HAS_CRYPTOGRAPHY, + reason="Skipping cryptography tests because cryptography is not installed") +class TestVaultCipherAes256(unittest.TestCase): + def setUp(self): + self.vault_cipher = vault.VaultAES256() + + def test(self): + self.assertIsInstance(self.vault_cipher, vault.VaultAES256) + + # TODO: tag these as slow tests + def test_create_key_cryptography(self): + b_password = b'hunter42' + b_salt = os.urandom(32) + b_key_cryptography = self.vault_cipher._create_key_cryptography(b_password, b_salt, key_length=32, iv_length=16) + self.assertIsInstance(b_key_cryptography, six.binary_type) + + @pytest.mark.skipif(not vault.HAS_PYCRYPTO, reason='Not testing pycrypto key as pycrypto is not installed') + def test_create_key_pycrypto(self): + b_password = b'hunter42' + b_salt = os.urandom(32) + + b_key_pycrypto = self.vault_cipher._create_key_pycrypto(b_password, b_salt, key_length=32, iv_length=16) + self.assertIsInstance(b_key_pycrypto, six.binary_type) + + @pytest.mark.skipif(not vault.HAS_PYCRYPTO, + reason='Not comparing cryptography key to pycrypto key as pycrypto is not installed') + def test_compare_new_keys(self): + b_password = b'hunter42' + b_salt = os.urandom(32) + b_key_cryptography = self.vault_cipher._create_key_cryptography(b_password, b_salt, key_length=32, iv_length=16) + + b_key_pycrypto = self.vault_cipher._create_key_pycrypto(b_password, b_salt, key_length=32, iv_length=16) + self.assertEqual(b_key_cryptography, b_key_pycrypto) + + def test_create_key_known_cryptography(self): + b_password = b'hunter42' + + # A fixed salt + b_salt = b'q' * 32 # q is the most random letter. + b_key_1 = self.vault_cipher._create_key_cryptography(b_password, b_salt, key_length=32, iv_length=16) + self.assertIsInstance(b_key_1, six.binary_type) + + # verify we get the same answer + # we could potentially run a few iterations of this and time it to see if it's roughly constant time + # and or that it exceeds some minimal time, but that would likely cause unreliable fails, esp in CI + b_key_2 = self.vault_cipher._create_key_cryptography(b_password, b_salt, key_length=32, iv_length=16) + self.assertIsInstance(b_key_2, six.binary_type) + self.assertEqual(b_key_1, b_key_2) + + # And again with pycrypto + b_key_3 = self.vault_cipher._create_key_pycrypto(b_password, b_salt, key_length=32, iv_length=16) + self.assertIsInstance(b_key_3, six.binary_type) + + # verify we get the same answer + # we could potentially run a few iterations of this and time it to see if it's roughly constant time + # and or that it exceeds some minimal time, but that would likely cause unreliable fails, esp in CI + b_key_4 = self.vault_cipher._create_key_pycrypto(b_password, b_salt, key_length=32, iv_length=16) + self.assertIsInstance(b_key_4, six.binary_type) + self.assertEqual(b_key_3, b_key_4) + self.assertEqual(b_key_1, b_key_4) + + def test_create_key_known_pycrypto(self): + b_password = b'hunter42' + + # A fixed salt + b_salt = b'q' * 32 # q is the most random letter. + b_key_3 = self.vault_cipher._create_key_pycrypto(b_password, b_salt, key_length=32, iv_length=16) + self.assertIsInstance(b_key_3, six.binary_type) + + # verify we get the same answer + # we could potentially run a few iterations of this and time it to see if it's roughly constant time + # and or that it exceeds some minimal time, but that would likely cause unreliable fails, esp in CI + b_key_4 = self.vault_cipher._create_key_pycrypto(b_password, b_salt, key_length=32, iv_length=16) + self.assertIsInstance(b_key_4, six.binary_type) + self.assertEqual(b_key_3, b_key_4) + + def test_is_equal_is_equal(self): + self.assertTrue(self.vault_cipher._is_equal(b'abcdefghijklmnopqrstuvwxyz', b'abcdefghijklmnopqrstuvwxyz')) + + def test_is_equal_unequal_length(self): + self.assertFalse(self.vault_cipher._is_equal(b'abcdefghijklmnopqrstuvwxyz', b'abcdefghijklmnopqrstuvwx and sometimes y')) + + def test_is_equal_not_equal(self): + self.assertFalse(self.vault_cipher._is_equal(b'abcdefghijklmnopqrstuvwxyz', b'AbcdefghijKlmnopQrstuvwxZ')) + + def test_is_equal_empty(self): + self.assertTrue(self.vault_cipher._is_equal(b'', b'')) + + def test_is_equal_non_ascii_equal(self): + utf8_data = to_bytes(u'私はガラスを食べられます。それは私を傷つけません。') + self.assertTrue(self.vault_cipher._is_equal(utf8_data, utf8_data)) + + def test_is_equal_non_ascii_unequal(self): + utf8_data = to_bytes(u'私はガラスを食べられます。それは私を傷つけません。') + utf8_data2 = to_bytes(u'Pot să mănânc sticlă și ea nu mă rănește.') + + # Test for the len optimization path + self.assertFalse(self.vault_cipher._is_equal(utf8_data, utf8_data2)) + # Test for the slower, char by char comparison path + self.assertFalse(self.vault_cipher._is_equal(utf8_data, utf8_data[:-1] + b'P')) + + def test_is_equal_non_bytes(self): + """ Anything not a byte string should raise a TypeError """ + self.assertRaises(TypeError, self.vault_cipher._is_equal, u"One fish", b"two fish") + self.assertRaises(TypeError, self.vault_cipher._is_equal, b"One fish", u"two fish") + self.assertRaises(TypeError, self.vault_cipher._is_equal, 1, b"red fish") + self.assertRaises(TypeError, self.vault_cipher._is_equal, b"blue fish", 2) + + +@pytest.mark.skipif(not vault.HAS_PYCRYPTO, + reason="Skipping Pycrypto tests because pycrypto is not installed") +class TestVaultCipherAes256PyCrypto(TestVaultCipherAes256): + def setUp(self): + self.has_cryptography = vault.HAS_CRYPTOGRAPHY + vault.HAS_CRYPTOGRAPHY = False + super(TestVaultCipherAes256PyCrypto, self).setUp() + + def tearDown(self): + vault.HAS_CRYPTOGRAPHY = self.has_cryptography + super(TestVaultCipherAes256PyCrypto, self).tearDown() + + +class TestMatchSecrets(unittest.TestCase): + def test_empty_tuple(self): + secrets = [tuple()] + vault_ids = ['vault_id_1'] + self.assertRaises(ValueError, + vault.match_secrets, + secrets, vault_ids) + + def test_empty_secrets(self): + matches = vault.match_secrets([], ['vault_id_1']) + self.assertEqual(matches, []) + + def test_single_match(self): + secret = TextVaultSecret('password') + matches = vault.match_secrets([('default', secret)], ['default']) + self.assertEqual(matches, [('default', secret)]) + + def test_no_matches(self): + secret = TextVaultSecret('password') + matches = vault.match_secrets([('default', secret)], ['not_default']) + self.assertEqual(matches, []) + + def test_multiple_matches(self): + secrets = [('vault_id1', TextVaultSecret('password1')), + ('vault_id2', TextVaultSecret('password2')), + ('vault_id1', TextVaultSecret('password3')), + ('vault_id4', TextVaultSecret('password4'))] + vault_ids = ['vault_id1', 'vault_id4'] + matches = vault.match_secrets(secrets, vault_ids) + + self.assertEqual(len(matches), 3) + expected = [('vault_id1', TextVaultSecret('password1')), + ('vault_id1', TextVaultSecret('password3')), + ('vault_id4', TextVaultSecret('password4'))] + self.assertEqual([x for x, y in matches], + [a for a, b in expected]) + + +@pytest.mark.skipif(not vault.HAS_CRYPTOGRAPHY, + reason="Skipping cryptography tests because cryptography is not installed") +class TestVaultLib(unittest.TestCase): + def setUp(self): + self.vault_password = "test-vault-password" + text_secret = TextVaultSecret(self.vault_password) + self.vault_secrets = [('default', text_secret), + ('test_id', text_secret)] + self.v = vault.VaultLib(self.vault_secrets) + + def _vault_secrets(self, vault_id, secret): + return [(vault_id, secret)] + + def _vault_secrets_from_password(self, vault_id, password): + return [(vault_id, TextVaultSecret(password))] + + def test_encrypt(self): + plaintext = u'Some text to encrypt in a café' + b_vaulttext = self.v.encrypt(plaintext) + + self.assertIsInstance(b_vaulttext, six.binary_type) + + b_header = b'$ANSIBLE_VAULT;1.1;AES256\n' + self.assertEqual(b_vaulttext[:len(b_header)], b_header) + + def test_encrypt_vault_id(self): + plaintext = u'Some text to encrypt in a café' + b_vaulttext = self.v.encrypt(plaintext, vault_id='test_id') + + self.assertIsInstance(b_vaulttext, six.binary_type) + + b_header = b'$ANSIBLE_VAULT;1.2;AES256;test_id\n' + self.assertEqual(b_vaulttext[:len(b_header)], b_header) + + def test_encrypt_bytes(self): + + plaintext = to_bytes(u'Some text to encrypt in a café') + b_vaulttext = self.v.encrypt(plaintext) + + self.assertIsInstance(b_vaulttext, six.binary_type) + + b_header = b'$ANSIBLE_VAULT;1.1;AES256\n' + self.assertEqual(b_vaulttext[:len(b_header)], b_header) + + def test_encrypt_no_secret_empty_secrets(self): + vault_secrets = [] + v = vault.VaultLib(vault_secrets) + + plaintext = u'Some text to encrypt in a café' + self.assertRaisesRegexp(vault.AnsibleVaultError, + '.*A vault password must be specified to encrypt data.*', + v.encrypt, + plaintext) + + def test_format_vaulttext_envelope(self): + cipher_name = "TEST" + b_ciphertext = b"ansible" + b_vaulttext = vault.format_vaulttext_envelope(b_ciphertext, + cipher_name, + version=self.v.b_version, + vault_id='default') + b_lines = b_vaulttext.split(b'\n') + self.assertGreater(len(b_lines), 1, msg="failed to properly add header") + + b_header = b_lines[0] + # self.assertTrue(b_header.endswith(b';TEST'), msg="header does not end with cipher name") + + b_header_parts = b_header.split(b';') + self.assertEqual(len(b_header_parts), 4, msg="header has the wrong number of parts") + self.assertEqual(b_header_parts[0], b'$ANSIBLE_VAULT', msg="header does not start with $ANSIBLE_VAULT") + self.assertEqual(b_header_parts[1], self.v.b_version, msg="header version is incorrect") + self.assertEqual(b_header_parts[2], b'TEST', msg="header does not end with cipher name") + + # And just to verify, lets parse the results and compare + b_ciphertext2, b_version2, cipher_name2, vault_id2 = \ + vault.parse_vaulttext_envelope(b_vaulttext) + self.assertEqual(b_ciphertext, b_ciphertext2) + self.assertEqual(self.v.b_version, b_version2) + self.assertEqual(cipher_name, cipher_name2) + self.assertEqual('default', vault_id2) + + def test_parse_vaulttext_envelope(self): + b_vaulttext = b"$ANSIBLE_VAULT;9.9;TEST\nansible" + b_ciphertext, b_version, cipher_name, vault_id = vault.parse_vaulttext_envelope(b_vaulttext) + b_lines = b_ciphertext.split(b'\n') + self.assertEqual(b_lines[0], b"ansible", msg="Payload was not properly split from the header") + self.assertEqual(cipher_name, u'TEST', msg="cipher name was not properly set") + self.assertEqual(b_version, b"9.9", msg="version was not properly set") + + def test_parse_vaulttext_envelope_crlf(self): + b_vaulttext = b"$ANSIBLE_VAULT;9.9;TEST\r\nansible" + b_ciphertext, b_version, cipher_name, vault_id = vault.parse_vaulttext_envelope(b_vaulttext) + b_lines = b_ciphertext.split(b'\n') + self.assertEqual(b_lines[0], b"ansible", msg="Payload was not properly split from the header") + self.assertEqual(cipher_name, u'TEST', msg="cipher name was not properly set") + self.assertEqual(b_version, b"9.9", msg="version was not properly set") + + def test_encrypt_decrypt_aes256(self): + self.v.cipher_name = u'AES256' + plaintext = u"foobar" + b_vaulttext = self.v.encrypt(plaintext) + b_plaintext = self.v.decrypt(b_vaulttext) + self.assertNotEqual(b_vaulttext, b"foobar", msg="encryption failed") + self.assertEqual(b_plaintext, b"foobar", msg="decryption failed") + + def test_encrypt_decrypt_aes256_none_secrets(self): + vault_secrets = self._vault_secrets_from_password('default', 'ansible') + v = vault.VaultLib(vault_secrets) + + plaintext = u"foobar" + b_vaulttext = v.encrypt(plaintext) + + # VaultLib will default to empty {} if secrets is None + v_none = vault.VaultLib(None) + # so set secrets None explicitly + v_none.secrets = None + self.assertRaisesRegexp(vault.AnsibleVaultError, + '.*A vault password must be specified to decrypt data.*', + v_none.decrypt, + b_vaulttext) + + def test_encrypt_decrypt_aes256_empty_secrets(self): + vault_secrets = self._vault_secrets_from_password('default', 'ansible') + v = vault.VaultLib(vault_secrets) + + plaintext = u"foobar" + b_vaulttext = v.encrypt(plaintext) + + vault_secrets_empty = [] + v_none = vault.VaultLib(vault_secrets_empty) + + self.assertRaisesRegexp(vault.AnsibleVaultError, + '.*Attempting to decrypt but no vault secrets found.*', + v_none.decrypt, + b_vaulttext) + + def test_encrypt_decrypt_aes256_multiple_secrets_all_wrong(self): + plaintext = u'Some text to encrypt in a café' + b_vaulttext = self.v.encrypt(plaintext) + + vault_secrets = [('default', TextVaultSecret('another-wrong-password')), + ('wrong-password', TextVaultSecret('wrong-password'))] + + v_multi = vault.VaultLib(vault_secrets) + self.assertRaisesRegexp(errors.AnsibleError, + '.*Decryption failed.*', + v_multi.decrypt, + b_vaulttext, + filename='/dev/null/fake/filename') + + def test_encrypt_decrypt_aes256_multiple_secrets_one_valid(self): + plaintext = u'Some text to encrypt in a café' + b_vaulttext = self.v.encrypt(plaintext) + + correct_secret = TextVaultSecret(self.vault_password) + wrong_secret = TextVaultSecret('wrong-password') + + vault_secrets = [('default', wrong_secret), + ('corect_secret', correct_secret), + ('wrong_secret', wrong_secret)] + + v_multi = vault.VaultLib(vault_secrets) + b_plaintext = v_multi.decrypt(b_vaulttext) + self.assertNotEqual(b_vaulttext, to_bytes(plaintext), msg="encryption failed") + self.assertEqual(b_plaintext, to_bytes(plaintext), msg="decryption failed") + + def test_encrypt_decrypt_aes256_existing_vault(self): + self.v.cipher_name = u'AES256' + b_orig_plaintext = b"Setec Astronomy" + vaulttext = u'''$ANSIBLE_VAULT;1.1;AES256 +33363965326261303234626463623963633531343539616138316433353830356566396130353436 +3562643163366231316662386565383735653432386435610a306664636137376132643732393835 +63383038383730306639353234326630666539346233376330303938323639306661313032396437 +6233623062366136310a633866373936313238333730653739323461656662303864663666653563 +3138''' + + b_plaintext = self.v.decrypt(vaulttext) + self.assertEqual(b_plaintext, b_plaintext, msg="decryption failed") + + b_vaulttext = to_bytes(vaulttext, encoding='ascii', errors='strict') + b_plaintext = self.v.decrypt(b_vaulttext) + self.assertEqual(b_plaintext, b_orig_plaintext, msg="decryption failed") + + # FIXME This test isn't working quite yet. + @pytest.mark.skip(reason='This test is not ready yet') + def test_encrypt_decrypt_aes256_bad_hmac(self): + + self.v.cipher_name = 'AES256' + # plaintext = "Setec Astronomy" + enc_data = '''$ANSIBLE_VAULT;1.1;AES256 +33363965326261303234626463623963633531343539616138316433353830356566396130353436 +3562643163366231316662386565383735653432386435610a306664636137376132643732393835 +63383038383730306639353234326630666539346233376330303938323639306661313032396437 +6233623062366136310a633866373936313238333730653739323461656662303864663666653563 +3138''' + b_data = to_bytes(enc_data, errors='strict', encoding='utf-8') + b_data = self.v._split_header(b_data) + foo = binascii.unhexlify(b_data) + lines = foo.splitlines() + # line 0 is salt, line 1 is hmac, line 2+ is ciphertext + b_salt = lines[0] + b_hmac = lines[1] + b_ciphertext_data = b'\n'.join(lines[2:]) + + b_ciphertext = binascii.unhexlify(b_ciphertext_data) + # b_orig_ciphertext = b_ciphertext[:] + + # now muck with the text + # b_munged_ciphertext = b_ciphertext[:10] + b'\x00' + b_ciphertext[11:] + # b_munged_ciphertext = b_ciphertext + # assert b_orig_ciphertext != b_munged_ciphertext + + b_ciphertext_data = binascii.hexlify(b_ciphertext) + b_payload = b'\n'.join([b_salt, b_hmac, b_ciphertext_data]) + # reformat + b_invalid_ciphertext = self.v._format_output(b_payload) + + # assert we throw an error + self.v.decrypt(b_invalid_ciphertext) + + def test_decrypt_and_get_vault_id(self): + b_expected_plaintext = to_bytes('foo bar\n') + vaulttext = '''$ANSIBLE_VAULT;1.2;AES256;ansible_devel +65616435333934613466373335363332373764363365633035303466643439313864663837393234 +3330656363343637313962633731333237313636633534630a386264363438363362326132363239 +39363166646664346264383934393935653933316263333838386362633534326664646166663736 +6462303664383765650a356637643633366663643566353036303162386237336233393065393164 +6264''' + + vault_secrets = self._vault_secrets_from_password('ansible_devel', 'ansible') + v = vault.VaultLib(vault_secrets) + + b_vaulttext = to_bytes(vaulttext) + + b_plaintext, vault_id_used, vault_secret_used = v.decrypt_and_get_vault_id(b_vaulttext) + + self.assertEqual(b_expected_plaintext, b_plaintext) + self.assertEqual(vault_id_used, 'ansible_devel') + self.assertEqual(vault_secret_used.text, 'ansible') + + def test_decrypt_non_default_1_2(self): + b_expected_plaintext = to_bytes('foo bar\n') + vaulttext = '''$ANSIBLE_VAULT;1.2;AES256;ansible_devel +65616435333934613466373335363332373764363365633035303466643439313864663837393234 +3330656363343637313962633731333237313636633534630a386264363438363362326132363239 +39363166646664346264383934393935653933316263333838386362633534326664646166663736 +6462303664383765650a356637643633366663643566353036303162386237336233393065393164 +6264''' + + vault_secrets = self._vault_secrets_from_password('default', 'ansible') + v = vault.VaultLib(vault_secrets) + + b_vaulttext = to_bytes(vaulttext) + + b_plaintext = v.decrypt(b_vaulttext) + self.assertEqual(b_expected_plaintext, b_plaintext) + + b_ciphertext, b_version, cipher_name, vault_id = vault.parse_vaulttext_envelope(b_vaulttext) + self.assertEqual('ansible_devel', vault_id) + self.assertEqual(b'1.2', b_version) + + def test_decrypt_decrypted(self): + plaintext = u"ansible" + self.assertRaises(errors.AnsibleError, self.v.decrypt, plaintext) + + b_plaintext = b"ansible" + self.assertRaises(errors.AnsibleError, self.v.decrypt, b_plaintext) + + def test_cipher_not_set(self): + plaintext = u"ansible" + self.v.encrypt(plaintext) + self.assertEqual(self.v.cipher_name, "AES256") + + +@pytest.mark.skipif(not vault.HAS_PYCRYPTO, + reason="Skipping Pycrypto tests because pycrypto is not installed") +class TestVaultLibPyCrypto(TestVaultLib): + def setUp(self): + self.has_cryptography = vault.HAS_CRYPTOGRAPHY + vault.HAS_CRYPTOGRAPHY = False + super(TestVaultLibPyCrypto, self).setUp() + + def tearDown(self): + vault.HAS_CRYPTOGRAPHY = self.has_cryptography + super(TestVaultLibPyCrypto, self).tearDown() diff --git a/test/units/parsing/vault/test_vault_editor.py b/test/units/parsing/vault/test_vault_editor.py new file mode 100644 index 00000000..8aa9b37c --- /dev/null +++ b/test/units/parsing/vault/test_vault_editor.py @@ -0,0 +1,517 @@ +# (c) 2014, James Tanner <tanner.jc@gmail.com> +# (c) 2014, James Cammarata, <jcammarata@ansible.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import tempfile + +import pytest + +from units.compat import unittest +from units.compat.mock import patch + +from ansible import errors +from ansible.parsing import vault +from ansible.parsing.vault import VaultLib, VaultEditor, match_encrypt_secret + +from ansible.module_utils._text import to_bytes, to_text + +from units.mock.vault_helper import TextVaultSecret + +v11_data = """$ANSIBLE_VAULT;1.1;AES256 +62303130653266653331306264616235333735323636616539316433666463323964623162386137 +3961616263373033353631316333623566303532663065310a393036623466376263393961326530 +64336561613965383835646464623865663966323464653236343638373165343863623638316664 +3631633031323837340a396530313963373030343933616133393566366137363761373930663833 +3739""" + + +@pytest.mark.skipif(not vault.HAS_CRYPTOGRAPHY, + reason="Skipping cryptography tests because cryptography is not installed") +class TestVaultEditor(unittest.TestCase): + + def setUp(self): + self._test_dir = None + self.vault_password = "test-vault-password" + vault_secret = TextVaultSecret(self.vault_password) + self.vault_secrets = [('vault_secret', vault_secret), + ('default', vault_secret)] + + @property + def vault_secret(self): + return match_encrypt_secret(self.vault_secrets)[1] + + def tearDown(self): + if self._test_dir: + pass + # shutil.rmtree(self._test_dir) + self._test_dir = None + + def _secrets(self, password): + vault_secret = TextVaultSecret(password) + vault_secrets = [('default', vault_secret)] + return vault_secrets + + def test_methods_exist(self): + v = vault.VaultEditor(None) + slots = ['create_file', + 'decrypt_file', + 'edit_file', + 'encrypt_file', + 'rekey_file', + 'read_data', + 'write_data'] + for slot in slots: + assert hasattr(v, slot), "VaultLib is missing the %s method" % slot + + def _create_test_dir(self): + suffix = '_ansible_unit_test_%s_' % (self.__class__.__name__) + return tempfile.mkdtemp(suffix=suffix) + + def _create_file(self, test_dir, name, content=None, symlink=False): + file_path = os.path.join(test_dir, name) + opened_file = open(file_path, 'wb') + if content: + opened_file.write(content) + opened_file.close() + return file_path + + def _vault_editor(self, vault_secrets=None): + if vault_secrets is None: + vault_secrets = self._secrets(self.vault_password) + return VaultEditor(VaultLib(vault_secrets)) + + @patch('ansible.parsing.vault.subprocess.call') + def test_edit_file_helper_empty_target(self, mock_sp_call): + self._test_dir = self._create_test_dir() + + src_contents = to_bytes("some info in a file\nyup.") + src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents) + + mock_sp_call.side_effect = self._faux_command + ve = self._vault_editor() + + b_ciphertext = ve._edit_file_helper(src_file_path, self.vault_secret) + + self.assertNotEqual(src_contents, b_ciphertext) + + @patch('ansible.parsing.vault.subprocess.call') + def test_edit_file_helper_call_exception(self, mock_sp_call): + self._test_dir = self._create_test_dir() + + src_contents = to_bytes("some info in a file\nyup.") + src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents) + + error_txt = 'calling editor raised an exception' + mock_sp_call.side_effect = errors.AnsibleError(error_txt) + + ve = self._vault_editor() + + self.assertRaisesRegexp(errors.AnsibleError, + error_txt, + ve._edit_file_helper, + src_file_path, + self.vault_secret) + + @patch('ansible.parsing.vault.subprocess.call') + def test_edit_file_helper_symlink_target(self, mock_sp_call): + self._test_dir = self._create_test_dir() + + src_file_contents = to_bytes("some info in a file\nyup.") + src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents) + + src_file_link_path = os.path.join(self._test_dir, 'a_link_to_dest_file') + + os.symlink(src_file_path, src_file_link_path) + + mock_sp_call.side_effect = self._faux_command + ve = self._vault_editor() + + b_ciphertext = ve._edit_file_helper(src_file_link_path, self.vault_secret) + + self.assertNotEqual(src_file_contents, b_ciphertext, + 'b_ciphertext should be encrypted and not equal to src_contents') + + def _faux_editor(self, editor_args, new_src_contents=None): + if editor_args[0] == 'shred': + return + + tmp_path = editor_args[-1] + + # simulate the tmp file being editted + tmp_file = open(tmp_path, 'wb') + if new_src_contents: + tmp_file.write(new_src_contents) + tmp_file.close() + + def _faux_command(self, tmp_path): + pass + + @patch('ansible.parsing.vault.subprocess.call') + def test_edit_file_helper_no_change(self, mock_sp_call): + self._test_dir = self._create_test_dir() + + src_file_contents = to_bytes("some info in a file\nyup.") + src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents) + + # editor invocation doesn't change anything + def faux_editor(editor_args): + self._faux_editor(editor_args, src_file_contents) + + mock_sp_call.side_effect = faux_editor + ve = self._vault_editor() + + ve._edit_file_helper(src_file_path, self.vault_secret, existing_data=src_file_contents) + + new_target_file = open(src_file_path, 'rb') + new_target_file_contents = new_target_file.read() + self.assertEqual(src_file_contents, new_target_file_contents) + + def _assert_file_is_encrypted(self, vault_editor, src_file_path, src_contents): + new_src_file = open(src_file_path, 'rb') + new_src_file_contents = new_src_file.read() + + # TODO: assert that it is encrypted + self.assertTrue(vault.is_encrypted(new_src_file_contents)) + + src_file_plaintext = vault_editor.vault.decrypt(new_src_file_contents) + + # the plaintext should not be encrypted + self.assertFalse(vault.is_encrypted(src_file_plaintext)) + + # and the new plaintext should match the original + self.assertEqual(src_file_plaintext, src_contents) + + def _assert_file_is_link(self, src_file_link_path, src_file_path): + self.assertTrue(os.path.islink(src_file_link_path), + 'The dest path (%s) should be a symlink to (%s) but is not' % (src_file_link_path, src_file_path)) + + def test_rekey_file(self): + self._test_dir = self._create_test_dir() + + src_file_contents = to_bytes("some info in a file\nyup.") + src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents) + + ve = self._vault_editor() + ve.encrypt_file(src_file_path, self.vault_secret) + + # FIXME: update to just set self._secrets or just a new vault secret id + new_password = 'password2:electricbugaloo' + new_vault_secret = TextVaultSecret(new_password) + new_vault_secrets = [('default', new_vault_secret)] + ve.rekey_file(src_file_path, vault.match_encrypt_secret(new_vault_secrets)[1]) + + # FIXME: can just update self._secrets here + new_ve = vault.VaultEditor(VaultLib(new_vault_secrets)) + self._assert_file_is_encrypted(new_ve, src_file_path, src_file_contents) + + def test_rekey_file_no_new_password(self): + self._test_dir = self._create_test_dir() + + src_file_contents = to_bytes("some info in a file\nyup.") + src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents) + + ve = self._vault_editor() + ve.encrypt_file(src_file_path, self.vault_secret) + + self.assertRaisesRegexp(errors.AnsibleError, + 'The value for the new_password to rekey', + ve.rekey_file, + src_file_path, + None) + + def test_rekey_file_not_encrypted(self): + self._test_dir = self._create_test_dir() + + src_file_contents = to_bytes("some info in a file\nyup.") + src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents) + + ve = self._vault_editor() + + new_password = 'password2:electricbugaloo' + self.assertRaisesRegexp(errors.AnsibleError, + 'input is not vault encrypted data', + ve.rekey_file, + src_file_path, new_password) + + def test_plaintext(self): + self._test_dir = self._create_test_dir() + + src_file_contents = to_bytes("some info in a file\nyup.") + src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents) + + ve = self._vault_editor() + ve.encrypt_file(src_file_path, self.vault_secret) + + res = ve.plaintext(src_file_path) + self.assertEqual(src_file_contents, res) + + def test_plaintext_not_encrypted(self): + self._test_dir = self._create_test_dir() + + src_file_contents = to_bytes("some info in a file\nyup.") + src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents) + + ve = self._vault_editor() + self.assertRaisesRegexp(errors.AnsibleError, + 'input is not vault encrypted data', + ve.plaintext, + src_file_path) + + def test_encrypt_file(self): + self._test_dir = self._create_test_dir() + src_file_contents = to_bytes("some info in a file\nyup.") + src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents) + + ve = self._vault_editor() + ve.encrypt_file(src_file_path, self.vault_secret) + + self._assert_file_is_encrypted(ve, src_file_path, src_file_contents) + + def test_encrypt_file_symlink(self): + self._test_dir = self._create_test_dir() + + src_file_contents = to_bytes("some info in a file\nyup.") + src_file_path = self._create_file(self._test_dir, 'src_file', content=src_file_contents) + + src_file_link_path = os.path.join(self._test_dir, 'a_link_to_dest_file') + os.symlink(src_file_path, src_file_link_path) + + ve = self._vault_editor() + ve.encrypt_file(src_file_link_path, self.vault_secret) + + self._assert_file_is_encrypted(ve, src_file_path, src_file_contents) + self._assert_file_is_encrypted(ve, src_file_link_path, src_file_contents) + + self._assert_file_is_link(src_file_link_path, src_file_path) + + @patch('ansible.parsing.vault.subprocess.call') + def test_edit_file_no_vault_id(self, mock_sp_call): + self._test_dir = self._create_test_dir() + src_contents = to_bytes("some info in a file\nyup.") + + src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents) + + new_src_contents = to_bytes("The info is different now.") + + def faux_editor(editor_args): + self._faux_editor(editor_args, new_src_contents) + + mock_sp_call.side_effect = faux_editor + + ve = self._vault_editor() + + ve.encrypt_file(src_file_path, self.vault_secret) + ve.edit_file(src_file_path) + + new_src_file = open(src_file_path, 'rb') + new_src_file_contents = new_src_file.read() + + self.assertTrue(b'$ANSIBLE_VAULT;1.1;AES256' in new_src_file_contents) + + src_file_plaintext = ve.vault.decrypt(new_src_file_contents) + self.assertEqual(src_file_plaintext, new_src_contents) + + @patch('ansible.parsing.vault.subprocess.call') + def test_edit_file_with_vault_id(self, mock_sp_call): + self._test_dir = self._create_test_dir() + src_contents = to_bytes("some info in a file\nyup.") + + src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents) + + new_src_contents = to_bytes("The info is different now.") + + def faux_editor(editor_args): + self._faux_editor(editor_args, new_src_contents) + + mock_sp_call.side_effect = faux_editor + + ve = self._vault_editor() + + ve.encrypt_file(src_file_path, self.vault_secret, + vault_id='vault_secrets') + ve.edit_file(src_file_path) + + new_src_file = open(src_file_path, 'rb') + new_src_file_contents = new_src_file.read() + + self.assertTrue(b'$ANSIBLE_VAULT;1.2;AES256;vault_secrets' in new_src_file_contents) + + src_file_plaintext = ve.vault.decrypt(new_src_file_contents) + self.assertEqual(src_file_plaintext, new_src_contents) + + @patch('ansible.parsing.vault.subprocess.call') + def test_edit_file_symlink(self, mock_sp_call): + self._test_dir = self._create_test_dir() + src_contents = to_bytes("some info in a file\nyup.") + + src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents) + + new_src_contents = to_bytes("The info is different now.") + + def faux_editor(editor_args): + self._faux_editor(editor_args, new_src_contents) + + mock_sp_call.side_effect = faux_editor + + ve = self._vault_editor() + + ve.encrypt_file(src_file_path, self.vault_secret) + + src_file_link_path = os.path.join(self._test_dir, 'a_link_to_dest_file') + + os.symlink(src_file_path, src_file_link_path) + + ve.edit_file(src_file_link_path) + + new_src_file = open(src_file_path, 'rb') + new_src_file_contents = new_src_file.read() + + src_file_plaintext = ve.vault.decrypt(new_src_file_contents) + + self._assert_file_is_link(src_file_link_path, src_file_path) + + self.assertEqual(src_file_plaintext, new_src_contents) + + # self.assertEqual(src_file_plaintext, new_src_contents, + # 'The decrypted plaintext of the editted file is not the expected contents.') + + @patch('ansible.parsing.vault.subprocess.call') + def test_edit_file_not_encrypted(self, mock_sp_call): + self._test_dir = self._create_test_dir() + src_contents = to_bytes("some info in a file\nyup.") + + src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents) + + new_src_contents = to_bytes("The info is different now.") + + def faux_editor(editor_args): + self._faux_editor(editor_args, new_src_contents) + + mock_sp_call.side_effect = faux_editor + + ve = self._vault_editor() + self.assertRaisesRegexp(errors.AnsibleError, + 'input is not vault encrypted data', + ve.edit_file, + src_file_path) + + def test_create_file_exists(self): + self._test_dir = self._create_test_dir() + src_contents = to_bytes("some info in a file\nyup.") + src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents) + + ve = self._vault_editor() + self.assertRaisesRegexp(errors.AnsibleError, + 'please use .edit. instead', + ve.create_file, + src_file_path, + self.vault_secret) + + def test_decrypt_file_exception(self): + self._test_dir = self._create_test_dir() + src_contents = to_bytes("some info in a file\nyup.") + src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents) + + ve = self._vault_editor() + self.assertRaisesRegexp(errors.AnsibleError, + 'input is not vault encrypted data', + ve.decrypt_file, + src_file_path) + + @patch.object(vault.VaultEditor, '_editor_shell_command') + def test_create_file(self, mock_editor_shell_command): + + def sc_side_effect(filename): + return ['touch', filename] + mock_editor_shell_command.side_effect = sc_side_effect + + tmp_file = tempfile.NamedTemporaryFile() + os.unlink(tmp_file.name) + + _secrets = self._secrets('ansible') + ve = self._vault_editor(_secrets) + ve.create_file(tmp_file.name, vault.match_encrypt_secret(_secrets)[1]) + + self.assertTrue(os.path.exists(tmp_file.name)) + + def test_decrypt_1_1(self): + v11_file = tempfile.NamedTemporaryFile(delete=False) + with v11_file as f: + f.write(to_bytes(v11_data)) + + ve = self._vault_editor(self._secrets("ansible")) + + # make sure the password functions for the cipher + error_hit = False + try: + ve.decrypt_file(v11_file.name) + except errors.AnsibleError: + error_hit = True + + # verify decrypted content + f = open(v11_file.name, "rb") + fdata = to_text(f.read()) + f.close() + + os.unlink(v11_file.name) + + assert error_hit is False, "error decrypting 1.1 file" + assert fdata.strip() == "foo", "incorrect decryption of 1.1 file: %s" % fdata.strip() + + def test_real_path_dash(self): + filename = '-' + ve = self._vault_editor() + + res = ve._real_path(filename) + self.assertEqual(res, '-') + + def test_real_path_dev_null(self): + filename = '/dev/null' + ve = self._vault_editor() + + res = ve._real_path(filename) + self.assertEqual(res, '/dev/null') + + def test_real_path_symlink(self): + self._test_dir = os.path.realpath(self._create_test_dir()) + file_path = self._create_file(self._test_dir, 'test_file', content=b'this is a test file') + file_link_path = os.path.join(self._test_dir, 'a_link_to_test_file') + + os.symlink(file_path, file_link_path) + + ve = self._vault_editor() + + res = ve._real_path(file_link_path) + self.assertEqual(res, file_path) + + +@pytest.mark.skipif(not vault.HAS_PYCRYPTO, + reason="Skipping pycrypto tests because pycrypto is not installed") +class TestVaultEditorPyCrypto(unittest.TestCase): + def setUp(self): + self.has_cryptography = vault.HAS_CRYPTOGRAPHY + vault.HAS_CRYPTOGRAPHY = False + super(TestVaultEditorPyCrypto, self).setUp() + + def tearDown(self): + vault.HAS_CRYPTOGRAPHY = self.has_cryptography + super(TestVaultEditorPyCrypto, self).tearDown() diff --git a/test/units/parsing/yaml/__init__.py b/test/units/parsing/yaml/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/units/parsing/yaml/__init__.py diff --git a/test/units/parsing/yaml/test_dumper.py b/test/units/parsing/yaml/test_dumper.py new file mode 100644 index 00000000..8129ca3a --- /dev/null +++ b/test/units/parsing/yaml/test_dumper.py @@ -0,0 +1,103 @@ +# coding: utf-8 +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import io + +from units.compat import unittest +from ansible.parsing import vault +from ansible.parsing.yaml import dumper, objects +from ansible.parsing.yaml.loader import AnsibleLoader +from ansible.module_utils.six import PY2 +from ansible.utils.unsafe_proxy import AnsibleUnsafeText, AnsibleUnsafeBytes + +from units.mock.yaml_helper import YamlTestUtils +from units.mock.vault_helper import TextVaultSecret + + +class TestAnsibleDumper(unittest.TestCase, YamlTestUtils): + def setUp(self): + self.vault_password = "hunter42" + vault_secret = TextVaultSecret(self.vault_password) + self.vault_secrets = [('vault_secret', vault_secret)] + self.good_vault = vault.VaultLib(self.vault_secrets) + self.vault = self.good_vault + self.stream = self._build_stream() + self.dumper = dumper.AnsibleDumper + + def _build_stream(self, yaml_text=None): + text = yaml_text or u'' + stream = io.StringIO(text) + return stream + + def _loader(self, stream): + return AnsibleLoader(stream, vault_secrets=self.vault.secrets) + + def test_ansible_vault_encrypted_unicode(self): + plaintext = 'This is a string we are going to encrypt.' + avu = objects.AnsibleVaultEncryptedUnicode.from_plaintext(plaintext, vault=self.vault, + secret=vault.match_secrets(self.vault_secrets, ['vault_secret'])[0][1]) + + yaml_out = self._dump_string(avu, dumper=self.dumper) + stream = self._build_stream(yaml_out) + loader = self._loader(stream) + + data_from_yaml = loader.get_single_data() + + self.assertEqual(plaintext, data_from_yaml.data) + + def test_bytes(self): + b_text = u'tréma'.encode('utf-8') + unsafe_object = AnsibleUnsafeBytes(b_text) + yaml_out = self._dump_string(unsafe_object, dumper=self.dumper) + + stream = self._build_stream(yaml_out) + loader = self._loader(stream) + + data_from_yaml = loader.get_single_data() + + result = b_text + if PY2: + # https://pyyaml.org/wiki/PyYAMLDocumentation#string-conversion-python-2-only + # pyyaml on Python 2 can return either unicode or bytes when given byte strings. + # We normalize that to always return unicode on Python2 as that's right most of the + # time. However, this means byte strings can round trip through yaml on Python3 but + # not on Python2. To make this code work the same on Python2 and Python3 (we want + # the Python3 behaviour) we need to change the methods in Ansible to: + # (1) Let byte strings pass through yaml without being converted on Python2 + # (2) Convert byte strings to text strings before being given to pyyaml (Without this, + # strings would end up as byte strings most of the time which would mostly be wrong) + # In practice, we mostly read bytes in from files and then pass that to pyyaml, for which + # the present behavior is correct. + # This is a workaround for the current behavior. + result = u'tr\xe9ma' + + self.assertEqual(result, data_from_yaml) + + def test_unicode(self): + u_text = u'nöel' + unsafe_object = AnsibleUnsafeText(u_text) + yaml_out = self._dump_string(unsafe_object, dumper=self.dumper) + + stream = self._build_stream(yaml_out) + loader = self._loader(stream) + + data_from_yaml = loader.get_single_data() + + self.assertEqual(u_text, data_from_yaml) diff --git a/test/units/parsing/yaml/test_loader.py b/test/units/parsing/yaml/test_loader.py new file mode 100644 index 00000000..d6989f44 --- /dev/null +++ b/test/units/parsing/yaml/test_loader.py @@ -0,0 +1,436 @@ +# coding: utf-8 +# (c) 2015, Toshio Kuratomi <tkuratomi@ansible.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from io import StringIO + +from units.compat import unittest + +from ansible import errors +from ansible.module_utils.six import text_type, binary_type +from ansible.module_utils.common._collections_compat import Sequence, Set, Mapping +from ansible.parsing.yaml.loader import AnsibleLoader +from ansible.parsing import vault +from ansible.parsing.yaml.objects import AnsibleVaultEncryptedUnicode +from ansible.parsing.yaml.dumper import AnsibleDumper + +from units.mock.yaml_helper import YamlTestUtils +from units.mock.vault_helper import TextVaultSecret + +try: + from _yaml import ParserError + from _yaml import ScannerError +except ImportError: + from yaml.parser import ParserError + from yaml.scanner import ScannerError + + +class NameStringIO(StringIO): + """In py2.6, StringIO doesn't let you set name because a baseclass has it + as readonly property""" + name = None + + def __init__(self, *args, **kwargs): + super(NameStringIO, self).__init__(*args, **kwargs) + + +class TestAnsibleLoaderBasic(unittest.TestCase): + + def test_parse_number(self): + stream = StringIO(u""" + 1 + """) + loader = AnsibleLoader(stream, 'myfile.yml') + data = loader.get_single_data() + self.assertEqual(data, 1) + # No line/column info saved yet + + def test_parse_string(self): + stream = StringIO(u""" + Ansible + """) + loader = AnsibleLoader(stream, 'myfile.yml') + data = loader.get_single_data() + self.assertEqual(data, u'Ansible') + self.assertIsInstance(data, text_type) + + self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17)) + + def test_parse_utf8_string(self): + stream = StringIO(u""" + Cafè Eñyei + """) + loader = AnsibleLoader(stream, 'myfile.yml') + data = loader.get_single_data() + self.assertEqual(data, u'Cafè Eñyei') + self.assertIsInstance(data, text_type) + + self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17)) + + def test_parse_dict(self): + stream = StringIO(u""" + webster: daniel + oed: oxford + """) + loader = AnsibleLoader(stream, 'myfile.yml') + data = loader.get_single_data() + self.assertEqual(data, {'webster': 'daniel', 'oed': 'oxford'}) + self.assertEqual(len(data), 2) + self.assertIsInstance(list(data.keys())[0], text_type) + self.assertIsInstance(list(data.values())[0], text_type) + + # Beginning of the first key + self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17)) + + self.assertEqual(data[u'webster'].ansible_pos, ('myfile.yml', 2, 26)) + self.assertEqual(data[u'oed'].ansible_pos, ('myfile.yml', 3, 22)) + + def test_parse_list(self): + stream = StringIO(u""" + - a + - b + """) + loader = AnsibleLoader(stream, 'myfile.yml') + data = loader.get_single_data() + self.assertEqual(data, [u'a', u'b']) + self.assertEqual(len(data), 2) + self.assertIsInstance(data[0], text_type) + + self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17)) + + self.assertEqual(data[0].ansible_pos, ('myfile.yml', 2, 19)) + self.assertEqual(data[1].ansible_pos, ('myfile.yml', 3, 19)) + + def test_parse_short_dict(self): + stream = StringIO(u"""{"foo": "bar"}""") + loader = AnsibleLoader(stream, 'myfile.yml') + data = loader.get_single_data() + self.assertEqual(data, dict(foo=u'bar')) + + self.assertEqual(data.ansible_pos, ('myfile.yml', 1, 1)) + self.assertEqual(data[u'foo'].ansible_pos, ('myfile.yml', 1, 9)) + + stream = StringIO(u"""foo: bar""") + loader = AnsibleLoader(stream, 'myfile.yml') + data = loader.get_single_data() + self.assertEqual(data, dict(foo=u'bar')) + + self.assertEqual(data.ansible_pos, ('myfile.yml', 1, 1)) + self.assertEqual(data[u'foo'].ansible_pos, ('myfile.yml', 1, 6)) + + def test_error_conditions(self): + stream = StringIO(u"""{""") + loader = AnsibleLoader(stream, 'myfile.yml') + self.assertRaises(ParserError, loader.get_single_data) + + def test_tab_error(self): + stream = StringIO(u"""---\nhosts: localhost\nvars:\n foo: bar\n\tblip: baz""") + loader = AnsibleLoader(stream, 'myfile.yml') + self.assertRaises(ScannerError, loader.get_single_data) + + def test_front_matter(self): + stream = StringIO(u"""---\nfoo: bar""") + loader = AnsibleLoader(stream, 'myfile.yml') + data = loader.get_single_data() + self.assertEqual(data, dict(foo=u'bar')) + + self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 1)) + self.assertEqual(data[u'foo'].ansible_pos, ('myfile.yml', 2, 6)) + + # Initial indent (See: #6348) + stream = StringIO(u""" - foo: bar\n baz: qux""") + loader = AnsibleLoader(stream, 'myfile.yml') + data = loader.get_single_data() + self.assertEqual(data, [{u'foo': u'bar', u'baz': u'qux'}]) + + self.assertEqual(data.ansible_pos, ('myfile.yml', 1, 2)) + self.assertEqual(data[0].ansible_pos, ('myfile.yml', 1, 4)) + self.assertEqual(data[0][u'foo'].ansible_pos, ('myfile.yml', 1, 9)) + self.assertEqual(data[0][u'baz'].ansible_pos, ('myfile.yml', 2, 9)) + + +class TestAnsibleLoaderVault(unittest.TestCase, YamlTestUtils): + def setUp(self): + self.vault_password = "hunter42" + vault_secret = TextVaultSecret(self.vault_password) + self.vault_secrets = [('vault_secret', vault_secret), + ('default', vault_secret)] + self.vault = vault.VaultLib(self.vault_secrets) + + @property + def vault_secret(self): + return vault.match_encrypt_secret(self.vault_secrets)[1] + + def test_wrong_password(self): + plaintext = u"Ansible" + bob_password = "this is a different password" + + bobs_secret = TextVaultSecret(bob_password) + bobs_secrets = [('default', bobs_secret)] + + bobs_vault = vault.VaultLib(bobs_secrets) + + ciphertext = bobs_vault.encrypt(plaintext, vault.match_encrypt_secret(bobs_secrets)[1]) + + try: + self.vault.decrypt(ciphertext) + except Exception as e: + self.assertIsInstance(e, errors.AnsibleError) + self.assertEqual(e.message, 'Decryption failed (no vault secrets were found that could decrypt)') + + def _encrypt_plaintext(self, plaintext): + # Construct a yaml repr of a vault by hand + vaulted_var_bytes = self.vault.encrypt(plaintext, self.vault_secret) + + # add yaml tag + vaulted_var = vaulted_var_bytes.decode() + lines = vaulted_var.splitlines() + lines2 = [] + for line in lines: + lines2.append(' %s' % line) + + vaulted_var = '\n'.join(lines2) + tagged_vaulted_var = u"""!vault |\n%s""" % vaulted_var + return tagged_vaulted_var + + def _build_stream(self, yaml_text): + stream = NameStringIO(yaml_text) + stream.name = 'my.yml' + return stream + + def _loader(self, stream): + return AnsibleLoader(stream, vault_secrets=self.vault.secrets) + + def _load_yaml(self, yaml_text, password): + stream = self._build_stream(yaml_text) + loader = self._loader(stream) + + data_from_yaml = loader.get_single_data() + + return data_from_yaml + + def test_dump_load_cycle(self): + avu = AnsibleVaultEncryptedUnicode.from_plaintext('The plaintext for test_dump_load_cycle.', self.vault, self.vault_secret) + self._dump_load_cycle(avu) + + def test_embedded_vault_from_dump(self): + avu = AnsibleVaultEncryptedUnicode.from_plaintext('setec astronomy', self.vault, self.vault_secret) + blip = {'stuff1': [{'a dict key': 24}, + {'shhh-ssh-secrets': avu, + 'nothing to see here': 'move along'}], + 'another key': 24.1} + + blip = ['some string', 'another string', avu] + stream = NameStringIO() + + self._dump_stream(blip, stream, dumper=AnsibleDumper) + + stream.seek(0) + + stream.seek(0) + + loader = self._loader(stream) + + data_from_yaml = loader.get_data() + + stream2 = NameStringIO(u'') + # verify we can dump the object again + self._dump_stream(data_from_yaml, stream2, dumper=AnsibleDumper) + + def test_embedded_vault(self): + plaintext_var = u"""This is the plaintext string.""" + tagged_vaulted_var = self._encrypt_plaintext(plaintext_var) + another_vaulted_var = self._encrypt_plaintext(plaintext_var) + + different_var = u"""A different string that is not the same as the first one.""" + different_vaulted_var = self._encrypt_plaintext(different_var) + + yaml_text = u"""---\nwebster: daniel\noed: oxford\nthe_secret: %s\nanother_secret: %s\ndifferent_secret: %s""" % (tagged_vaulted_var, + another_vaulted_var, + different_vaulted_var) + + data_from_yaml = self._load_yaml(yaml_text, self.vault_password) + vault_string = data_from_yaml['the_secret'] + + self.assertEqual(plaintext_var, data_from_yaml['the_secret']) + + test_dict = {} + test_dict[vault_string] = 'did this work?' + + self.assertEqual(vault_string.data, vault_string) + + # This looks weird and useless, but the object in question has a custom __eq__ + self.assertEqual(vault_string, vault_string) + + another_vault_string = data_from_yaml['another_secret'] + different_vault_string = data_from_yaml['different_secret'] + + self.assertEqual(vault_string, another_vault_string) + self.assertNotEquals(vault_string, different_vault_string) + + # More testing of __eq__/__ne__ + self.assertTrue('some string' != vault_string) + self.assertNotEquals('some string', vault_string) + + # Note this is a compare of the str/unicode of these, they are different types + # so we want to test self == other, and other == self etc + self.assertEqual(plaintext_var, vault_string) + self.assertEqual(vault_string, plaintext_var) + self.assertFalse(plaintext_var != vault_string) + self.assertFalse(vault_string != plaintext_var) + + +class TestAnsibleLoaderPlay(unittest.TestCase): + + def setUp(self): + stream = NameStringIO(u""" + - hosts: localhost + vars: + number: 1 + string: Ansible + utf8_string: Cafè Eñyei + dictionary: + webster: daniel + oed: oxford + list: + - a + - b + - 1 + - 2 + tasks: + - name: Test case + ping: + data: "{{ utf8_string }}" + + - name: Test 2 + ping: + data: "Cafè Eñyei" + + - name: Test 3 + command: "printf 'Cafè Eñyei\\n'" + """) + self.play_filename = '/path/to/myplay.yml' + stream.name = self.play_filename + self.loader = AnsibleLoader(stream) + self.data = self.loader.get_single_data() + + def tearDown(self): + pass + + def test_data_complete(self): + self.assertEqual(len(self.data), 1) + self.assertIsInstance(self.data, list) + self.assertEqual(frozenset(self.data[0].keys()), frozenset((u'hosts', u'vars', u'tasks'))) + + self.assertEqual(self.data[0][u'hosts'], u'localhost') + + self.assertEqual(self.data[0][u'vars'][u'number'], 1) + self.assertEqual(self.data[0][u'vars'][u'string'], u'Ansible') + self.assertEqual(self.data[0][u'vars'][u'utf8_string'], u'Cafè Eñyei') + self.assertEqual(self.data[0][u'vars'][u'dictionary'], { + u'webster': u'daniel', + u'oed': u'oxford' + }) + self.assertEqual(self.data[0][u'vars'][u'list'], [u'a', u'b', 1, 2]) + + self.assertEqual(self.data[0][u'tasks'], [ + {u'name': u'Test case', u'ping': {u'data': u'{{ utf8_string }}'}}, + {u'name': u'Test 2', u'ping': {u'data': u'Cafè Eñyei'}}, + {u'name': u'Test 3', u'command': u'printf \'Cafè Eñyei\n\''}, + ]) + + def walk(self, data): + # Make sure there's no str in the data + self.assertNotIsInstance(data, binary_type) + + # Descend into various container types + if isinstance(data, text_type): + # strings are a sequence so we have to be explicit here + return + elif isinstance(data, (Sequence, Set)): + for element in data: + self.walk(element) + elif isinstance(data, Mapping): + for k, v in data.items(): + self.walk(k) + self.walk(v) + + # Scalars were all checked so we're good to go + return + + def test_no_str_in_data(self): + # Checks that no strings are str type + self.walk(self.data) + + def check_vars(self): + # Numbers don't have line/col information yet + # self.assertEqual(self.data[0][u'vars'][u'number'].ansible_pos, (self.play_filename, 4, 21)) + + self.assertEqual(self.data[0][u'vars'][u'string'].ansible_pos, (self.play_filename, 5, 29)) + self.assertEqual(self.data[0][u'vars'][u'utf8_string'].ansible_pos, (self.play_filename, 6, 34)) + + self.assertEqual(self.data[0][u'vars'][u'dictionary'].ansible_pos, (self.play_filename, 8, 23)) + self.assertEqual(self.data[0][u'vars'][u'dictionary'][u'webster'].ansible_pos, (self.play_filename, 8, 32)) + self.assertEqual(self.data[0][u'vars'][u'dictionary'][u'oed'].ansible_pos, (self.play_filename, 9, 28)) + + self.assertEqual(self.data[0][u'vars'][u'list'].ansible_pos, (self.play_filename, 11, 23)) + self.assertEqual(self.data[0][u'vars'][u'list'][0].ansible_pos, (self.play_filename, 11, 25)) + self.assertEqual(self.data[0][u'vars'][u'list'][1].ansible_pos, (self.play_filename, 12, 25)) + # Numbers don't have line/col info yet + # self.assertEqual(self.data[0][u'vars'][u'list'][2].ansible_pos, (self.play_filename, 13, 25)) + # self.assertEqual(self.data[0][u'vars'][u'list'][3].ansible_pos, (self.play_filename, 14, 25)) + + def check_tasks(self): + # + # First Task + # + self.assertEqual(self.data[0][u'tasks'][0].ansible_pos, (self.play_filename, 16, 23)) + self.assertEqual(self.data[0][u'tasks'][0][u'name'].ansible_pos, (self.play_filename, 16, 29)) + self.assertEqual(self.data[0][u'tasks'][0][u'ping'].ansible_pos, (self.play_filename, 18, 25)) + self.assertEqual(self.data[0][u'tasks'][0][u'ping'][u'data'].ansible_pos, (self.play_filename, 18, 31)) + + # + # Second Task + # + self.assertEqual(self.data[0][u'tasks'][1].ansible_pos, (self.play_filename, 20, 23)) + self.assertEqual(self.data[0][u'tasks'][1][u'name'].ansible_pos, (self.play_filename, 20, 29)) + self.assertEqual(self.data[0][u'tasks'][1][u'ping'].ansible_pos, (self.play_filename, 22, 25)) + self.assertEqual(self.data[0][u'tasks'][1][u'ping'][u'data'].ansible_pos, (self.play_filename, 22, 31)) + + # + # Third Task + # + self.assertEqual(self.data[0][u'tasks'][2].ansible_pos, (self.play_filename, 24, 23)) + self.assertEqual(self.data[0][u'tasks'][2][u'name'].ansible_pos, (self.play_filename, 24, 29)) + self.assertEqual(self.data[0][u'tasks'][2][u'command'].ansible_pos, (self.play_filename, 25, 32)) + + def test_line_numbers(self): + # Check the line/column numbers are correct + # Note: Remember, currently dicts begin at the start of their first entry + self.assertEqual(self.data[0].ansible_pos, (self.play_filename, 2, 19)) + self.assertEqual(self.data[0][u'hosts'].ansible_pos, (self.play_filename, 2, 26)) + self.assertEqual(self.data[0][u'vars'].ansible_pos, (self.play_filename, 4, 21)) + + self.check_vars() + + self.assertEqual(self.data[0][u'tasks'].ansible_pos, (self.play_filename, 16, 21)) + + self.check_tasks() diff --git a/test/units/parsing/yaml/test_objects.py b/test/units/parsing/yaml/test_objects.py new file mode 100644 index 00000000..d4529eed --- /dev/null +++ b/test/units/parsing/yaml/test_objects.py @@ -0,0 +1,164 @@ +# This file is part of Ansible +# -*- coding: utf-8 -*- +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright 2016, Adrian Likins <alikins@redhat.com> + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from units.compat import unittest + +from ansible.errors import AnsibleError + +from ansible.module_utils._text import to_native + +from ansible.parsing import vault +from ansible.parsing.yaml.loader import AnsibleLoader + +# module under test +from ansible.parsing.yaml import objects + +from units.mock.yaml_helper import YamlTestUtils +from units.mock.vault_helper import TextVaultSecret + + +class TestAnsibleVaultUnicodeNoVault(unittest.TestCase, YamlTestUtils): + def test_empty_init(self): + self.assertRaises(TypeError, objects.AnsibleVaultEncryptedUnicode) + + def test_empty_string_init(self): + seq = ''.encode('utf8') + self.assert_values(seq) + + def test_empty_byte_string_init(self): + seq = b'' + self.assert_values(seq) + + def _assert_values(self, avu, seq): + self.assertIsInstance(avu, objects.AnsibleVaultEncryptedUnicode) + self.assertTrue(avu.vault is None) + # AnsibleVaultEncryptedUnicode without a vault should never == any string + self.assertNotEquals(avu, seq) + + def assert_values(self, seq): + avu = objects.AnsibleVaultEncryptedUnicode(seq) + self._assert_values(avu, seq) + + def test_single_char(self): + seq = 'a'.encode('utf8') + self.assert_values(seq) + + def test_string(self): + seq = 'some letters' + self.assert_values(seq) + + def test_byte_string(self): + seq = 'some letters'.encode('utf8') + self.assert_values(seq) + + +class TestAnsibleVaultEncryptedUnicode(unittest.TestCase, YamlTestUtils): + def setUp(self): + self.good_vault_password = "hunter42" + good_vault_secret = TextVaultSecret(self.good_vault_password) + self.good_vault_secrets = [('good_vault_password', good_vault_secret)] + self.good_vault = vault.VaultLib(self.good_vault_secrets) + + # TODO: make this use two vault secret identities instead of two vaultSecrets + self.wrong_vault_password = 'not-hunter42' + wrong_vault_secret = TextVaultSecret(self.wrong_vault_password) + self.wrong_vault_secrets = [('wrong_vault_password', wrong_vault_secret)] + self.wrong_vault = vault.VaultLib(self.wrong_vault_secrets) + + self.vault = self.good_vault + self.vault_secrets = self.good_vault_secrets + + def _loader(self, stream): + return AnsibleLoader(stream, vault_secrets=self.vault_secrets) + + def test_dump_load_cycle(self): + aveu = self._from_plaintext('the test string for TestAnsibleVaultEncryptedUnicode.test_dump_load_cycle') + self._dump_load_cycle(aveu) + + def assert_values(self, avu, seq): + self.assertIsInstance(avu, objects.AnsibleVaultEncryptedUnicode) + + self.assertEqual(avu, seq) + self.assertTrue(avu.vault is self.vault) + self.assertIsInstance(avu.vault, vault.VaultLib) + + def _from_plaintext(self, seq): + id_secret = vault.match_encrypt_secret(self.good_vault_secrets) + return objects.AnsibleVaultEncryptedUnicode.from_plaintext(seq, vault=self.vault, secret=id_secret[1]) + + def _from_ciphertext(self, ciphertext): + avu = objects.AnsibleVaultEncryptedUnicode(ciphertext) + avu.vault = self.vault + return avu + + def test_empty_init(self): + self.assertRaises(TypeError, objects.AnsibleVaultEncryptedUnicode) + + def test_empty_string_init_from_plaintext(self): + seq = '' + avu = self._from_plaintext(seq) + self.assert_values(avu, seq) + + def test_empty_unicode_init_from_plaintext(self): + seq = u'' + avu = self._from_plaintext(seq) + self.assert_values(avu, seq) + + def test_string_from_plaintext(self): + seq = 'some letters' + avu = self._from_plaintext(seq) + self.assert_values(avu, seq) + + def test_unicode_from_plaintext(self): + seq = u'some letters' + avu = self._from_plaintext(seq) + self.assert_values(avu, seq) + + def test_unicode_from_plaintext_encode(self): + seq = u'some text here' + avu = self._from_plaintext(seq) + b_avu = avu.encode('utf-8', 'strict') + self.assertIsInstance(avu, objects.AnsibleVaultEncryptedUnicode) + self.assertEqual(b_avu, seq.encode('utf-8', 'strict')) + self.assertTrue(avu.vault is self.vault) + self.assertIsInstance(avu.vault, vault.VaultLib) + + # TODO/FIXME: make sure bad password fails differently than 'thats not encrypted' + def test_empty_string_wrong_password(self): + seq = '' + self.vault = self.wrong_vault + avu = self._from_plaintext(seq) + + def compare(avu, seq): + return avu == seq + + self.assertRaises(AnsibleError, compare, avu, seq) + + def test_vaulted_utf8_value_37258(self): + seq = u"aöffü" + avu = self._from_plaintext(seq) + self.assert_values(avu, seq) + + def test_str_vaulted_utf8_value_37258(self): + seq = u"aöffü" + avu = self._from_plaintext(seq) + assert str(avu) == to_native(seq) |