diff options
Diffstat (limited to '')
40 files changed, 5481 insertions, 0 deletions
diff --git a/test/units/plugins/__init__.py b/test/units/plugins/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/units/plugins/__init__.py diff --git a/test/units/plugins/action/__init__.py b/test/units/plugins/action/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/units/plugins/action/__init__.py diff --git a/test/units/plugins/action/test_action.py b/test/units/plugins/action/test_action.py new file mode 100644 index 00000000..12488019 --- /dev/null +++ b/test/units/plugins/action/test_action.py @@ -0,0 +1,683 @@ +# -*- coding: utf-8 -*- +# (c) 2015, Florian Apolloner <florian@apolloner.eu> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import re + +from ansible import constants as C +from units.compat import unittest +from units.compat.mock import patch, MagicMock, mock_open + +from ansible.errors import AnsibleError +from ansible.module_utils.six import text_type +from ansible.module_utils.six.moves import shlex_quote, builtins +from ansible.module_utils._text import to_bytes +from ansible.playbook.play_context import PlayContext +from ansible.plugins.action import ActionBase +from ansible.template import Templar +from ansible.vars.clean import clean_facts + +from units.mock.loader import DictDataLoader + + +python_module_replacers = br""" +#!/usr/bin/python + +#ANSIBLE_VERSION = "<<ANSIBLE_VERSION>>" +#MODULE_COMPLEX_ARGS = "<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>" +#SELINUX_SPECIAL_FS="<<SELINUX_SPECIAL_FILESYSTEMS>>" + +test = u'Toshio \u304f\u3089\u3068\u307f' +from ansible.module_utils.basic import * +""" + +powershell_module_replacers = b""" +WINDOWS_ARGS = "<<INCLUDE_ANSIBLE_MODULE_JSON_ARGS>>" +# POWERSHELL_COMMON +""" + + +def _action_base(): + fake_loader = DictDataLoader({ + }) + mock_module_loader = MagicMock() + mock_shared_loader_obj = MagicMock() + mock_shared_loader_obj.module_loader = mock_module_loader + mock_connection_loader = MagicMock() + + mock_shared_loader_obj.connection_loader = mock_connection_loader + mock_connection = MagicMock() + + play_context = MagicMock() + + action_base = DerivedActionBase(task=None, + connection=mock_connection, + play_context=play_context, + loader=fake_loader, + templar=None, + shared_loader_obj=mock_shared_loader_obj) + return action_base + + +class DerivedActionBase(ActionBase): + TRANSFERS_FILES = False + + def run(self, tmp=None, task_vars=None): + # We're not testing the plugin run() method, just the helper + # methods ActionBase defines + return super(DerivedActionBase, self).run(tmp=tmp, task_vars=task_vars) + + +class TestActionBase(unittest.TestCase): + + def test_action_base_run(self): + mock_task = MagicMock() + mock_task.action = "foo" + mock_task.args = dict(a=1, b=2, c=3) + + mock_connection = MagicMock() + + play_context = PlayContext() + + mock_task.async_val = None + action_base = DerivedActionBase(mock_task, mock_connection, play_context, None, None, None) + results = action_base.run() + self.assertEqual(results, dict()) + + mock_task.async_val = 0 + action_base = DerivedActionBase(mock_task, mock_connection, play_context, None, None, None) + results = action_base.run() + self.assertEqual(results, {}) + + def test_action_base__configure_module(self): + fake_loader = DictDataLoader({ + }) + + # create our fake task + mock_task = MagicMock() + mock_task.action = "copy" + mock_task.async_val = 0 + mock_task.delegate_to = None + + # create a mock connection, so we don't actually try and connect to things + mock_connection = MagicMock() + + # create a mock shared loader object + def mock_find_plugin_with_context(name, options, collection_list=None): + mockctx = MagicMock() + if name == 'badmodule': + mockctx.resolved = False + mockctx.plugin_resolved_path = None + elif '.ps1' in options: + mockctx.resolved = True + mockctx.plugin_resolved_path = '/fake/path/to/%s.ps1' % name + else: + mockctx.resolved = True + mockctx.plugin_resolved_path = '/fake/path/to/%s' % name + return mockctx + + mock_module_loader = MagicMock() + mock_module_loader.find_plugin_with_context.side_effect = mock_find_plugin_with_context + mock_shared_obj_loader = MagicMock() + mock_shared_obj_loader.module_loader = mock_module_loader + + # we're using a real play context here + play_context = PlayContext() + + # our test class + action_base = DerivedActionBase( + task=mock_task, + connection=mock_connection, + play_context=play_context, + loader=fake_loader, + templar=Templar(loader=fake_loader), + shared_loader_obj=mock_shared_obj_loader, + ) + + # test python module formatting + with patch.object(builtins, 'open', mock_open(read_data=to_bytes(python_module_replacers.strip(), encoding='utf-8'))): + with patch.object(os, 'rename'): + mock_task.args = dict(a=1, foo='fö〩') + mock_connection.module_implementation_preferences = ('',) + (style, shebang, data, path) = action_base._configure_module(mock_task.action, mock_task.args, + task_vars=dict(ansible_python_interpreter='/usr/bin/python')) + self.assertEqual(style, "new") + self.assertEqual(shebang, u"#!/usr/bin/python") + + # test module not found + self.assertRaises(AnsibleError, action_base._configure_module, 'badmodule', mock_task.args, {}) + + # test powershell module formatting + with patch.object(builtins, 'open', mock_open(read_data=to_bytes(powershell_module_replacers.strip(), encoding='utf-8'))): + mock_task.action = 'win_copy' + mock_task.args = dict(b=2) + mock_connection.module_implementation_preferences = ('.ps1',) + (style, shebang, data, path) = action_base._configure_module('stat', mock_task.args, {}) + self.assertEqual(style, "new") + self.assertEqual(shebang, u'#!powershell') + + # test module not found + self.assertRaises(AnsibleError, action_base._configure_module, 'badmodule', mock_task.args, {}) + + def test_action_base__compute_environment_string(self): + fake_loader = DictDataLoader({ + }) + + # create our fake task + mock_task = MagicMock() + mock_task.action = "copy" + mock_task.args = dict(a=1) + + # create a mock connection, so we don't actually try and connect to things + def env_prefix(**args): + return ' '.join(['%s=%s' % (k, shlex_quote(text_type(v))) for k, v in args.items()]) + mock_connection = MagicMock() + mock_connection._shell.env_prefix.side_effect = env_prefix + + # we're using a real play context here + play_context = PlayContext() + + # and we're using a real templar here too + templar = Templar(loader=fake_loader) + + # our test class + action_base = DerivedActionBase( + task=mock_task, + connection=mock_connection, + play_context=play_context, + loader=fake_loader, + templar=templar, + shared_loader_obj=None, + ) + + # test standard environment setup + mock_task.environment = [dict(FOO='foo'), None] + env_string = action_base._compute_environment_string() + self.assertEqual(env_string, "FOO=foo") + + # test where environment is not a list + mock_task.environment = dict(FOO='foo') + env_string = action_base._compute_environment_string() + self.assertEqual(env_string, "FOO=foo") + + # test environment with a variable in it + templar.available_variables = dict(the_var='bar') + mock_task.environment = [dict(FOO='{{the_var}}')] + env_string = action_base._compute_environment_string() + self.assertEqual(env_string, "FOO=bar") + + # test with a bad environment set + mock_task.environment = dict(FOO='foo') + mock_task.environment = ['hi there'] + self.assertRaises(AnsibleError, action_base._compute_environment_string) + + def test_action_base__early_needs_tmp_path(self): + # create our fake task + mock_task = MagicMock() + + # create a mock connection, so we don't actually try and connect to things + mock_connection = MagicMock() + + # we're using a real play context here + play_context = PlayContext() + + # our test class + action_base = DerivedActionBase( + task=mock_task, + connection=mock_connection, + play_context=play_context, + loader=None, + templar=None, + shared_loader_obj=None, + ) + + self.assertFalse(action_base._early_needs_tmp_path()) + + action_base.TRANSFERS_FILES = True + self.assertTrue(action_base._early_needs_tmp_path()) + + def test_action_base__make_tmp_path(self): + # create our fake task + mock_task = MagicMock() + + def get_shell_opt(opt): + + ret = None + if opt == 'admin_users': + ret = ['root', 'toor', 'Administrator'] + elif opt == 'remote_tmp': + ret = '~/.ansible/tmp' + + return ret + + # create a mock connection, so we don't actually try and connect to things + mock_connection = MagicMock() + mock_connection.transport = 'ssh' + mock_connection._shell.mkdtemp.return_value = 'mkdir command' + mock_connection._shell.join_path.side_effect = os.path.join + mock_connection._shell.get_option = get_shell_opt + mock_connection._shell.HOMES_RE = re.compile(r'(\'|\")?(~|\$HOME)(.*)') + + # we're using a real play context here + play_context = PlayContext() + play_context.become = True + play_context.become_user = 'foo' + + # our test class + action_base = DerivedActionBase( + task=mock_task, + connection=mock_connection, + play_context=play_context, + loader=None, + templar=None, + shared_loader_obj=None, + ) + + action_base._low_level_execute_command = MagicMock() + action_base._low_level_execute_command.return_value = dict(rc=0, stdout='/some/path') + self.assertEqual(action_base._make_tmp_path('root'), '/some/path/') + + # empty path fails + action_base._low_level_execute_command.return_value = dict(rc=0, stdout='') + self.assertRaises(AnsibleError, action_base._make_tmp_path, 'root') + + # authentication failure + action_base._low_level_execute_command.return_value = dict(rc=5, stdout='') + self.assertRaises(AnsibleError, action_base._make_tmp_path, 'root') + + # ssh error + action_base._low_level_execute_command.return_value = dict(rc=255, stdout='', stderr='') + self.assertRaises(AnsibleError, action_base._make_tmp_path, 'root') + play_context.verbosity = 5 + self.assertRaises(AnsibleError, action_base._make_tmp_path, 'root') + + # general error + action_base._low_level_execute_command.return_value = dict(rc=1, stdout='some stuff here', stderr='') + self.assertRaises(AnsibleError, action_base._make_tmp_path, 'root') + action_base._low_level_execute_command.return_value = dict(rc=1, stdout='some stuff here', stderr='No space left on device') + self.assertRaises(AnsibleError, action_base._make_tmp_path, 'root') + + def test_action_base__remove_tmp_path(self): + # create our fake task + mock_task = MagicMock() + + # create a mock connection, so we don't actually try and connect to things + mock_connection = MagicMock() + mock_connection._shell.remove.return_value = 'rm some stuff' + + # we're using a real play context here + play_context = PlayContext() + + # our test class + action_base = DerivedActionBase( + task=mock_task, + connection=mock_connection, + play_context=play_context, + loader=None, + templar=None, + shared_loader_obj=None, + ) + + action_base._low_level_execute_command = MagicMock() + # these don't really return anything or raise errors, so + # we're pretty much calling these for coverage right now + action_base._remove_tmp_path('/bad/path/dont/remove') + action_base._remove_tmp_path('/good/path/to/ansible-tmp-thing') + + @patch('os.unlink') + @patch('os.fdopen') + @patch('tempfile.mkstemp') + def test_action_base__transfer_data(self, mock_mkstemp, mock_fdopen, mock_unlink): + # create our fake task + mock_task = MagicMock() + + # create a mock connection, so we don't actually try and connect to things + mock_connection = MagicMock() + mock_connection.put_file.return_value = None + + # we're using a real play context here + play_context = PlayContext() + + # our test class + action_base = DerivedActionBase( + task=mock_task, + connection=mock_connection, + play_context=play_context, + loader=None, + templar=None, + shared_loader_obj=None, + ) + + mock_afd = MagicMock() + mock_afile = MagicMock() + mock_mkstemp.return_value = (mock_afd, mock_afile) + + mock_unlink.return_value = None + + mock_afo = MagicMock() + mock_afo.write.return_value = None + mock_afo.flush.return_value = None + mock_afo.close.return_value = None + mock_fdopen.return_value = mock_afo + + self.assertEqual(action_base._transfer_data('/path/to/remote/file', 'some data'), '/path/to/remote/file') + self.assertEqual(action_base._transfer_data('/path/to/remote/file', 'some mixed data: fö〩'), '/path/to/remote/file') + self.assertEqual(action_base._transfer_data('/path/to/remote/file', dict(some_key='some value')), '/path/to/remote/file') + self.assertEqual(action_base._transfer_data('/path/to/remote/file', dict(some_key='fö〩')), '/path/to/remote/file') + + mock_afo.write.side_effect = Exception() + self.assertRaises(AnsibleError, action_base._transfer_data, '/path/to/remote/file', '') + + def test_action_base__execute_remote_stat(self): + # create our fake task + mock_task = MagicMock() + + # create a mock connection, so we don't actually try and connect to things + mock_connection = MagicMock() + + # we're using a real play context here + play_context = PlayContext() + + # our test class + action_base = DerivedActionBase( + task=mock_task, + connection=mock_connection, + play_context=play_context, + loader=None, + templar=None, + shared_loader_obj=None, + ) + + action_base._execute_module = MagicMock() + + # test normal case + action_base._execute_module.return_value = dict(stat=dict(checksum='1111111111111111111111111111111111', exists=True)) + res = action_base._execute_remote_stat(path='/path/to/file', all_vars=dict(), follow=False) + self.assertEqual(res['checksum'], '1111111111111111111111111111111111') + + # test does not exist + action_base._execute_module.return_value = dict(stat=dict(exists=False)) + res = action_base._execute_remote_stat(path='/path/to/file', all_vars=dict(), follow=False) + self.assertFalse(res['exists']) + self.assertEqual(res['checksum'], '1') + + # test no checksum in result from _execute_module + action_base._execute_module.return_value = dict(stat=dict(exists=True)) + res = action_base._execute_remote_stat(path='/path/to/file', all_vars=dict(), follow=False) + self.assertTrue(res['exists']) + self.assertEqual(res['checksum'], '') + + # test stat call failed + action_base._execute_module.return_value = dict(failed=True, msg="because I said so") + self.assertRaises(AnsibleError, action_base._execute_remote_stat, path='/path/to/file', all_vars=dict(), follow=False) + + def test_action_base__execute_module(self): + # create our fake task + mock_task = MagicMock() + mock_task.action = 'copy' + mock_task.args = dict(a=1, b=2, c=3) + + # create a mock connection, so we don't actually try and connect to things + def build_module_command(env_string, shebang, cmd, arg_path=None): + to_run = [env_string, cmd] + if arg_path: + to_run.append(arg_path) + return " ".join(to_run) + + def get_option(option): + return {'admin_users': ['root', 'toor']}.get(option) + + mock_connection = MagicMock() + mock_connection.build_module_command.side_effect = build_module_command + mock_connection.socket_path = None + mock_connection._shell.get_remote_filename.return_value = 'copy.py' + mock_connection._shell.join_path.side_effect = os.path.join + mock_connection._shell.tmpdir = '/var/tmp/mytempdir' + mock_connection._shell.get_option = get_option + + # we're using a real play context here + play_context = PlayContext() + + # our test class + action_base = DerivedActionBase( + task=mock_task, + connection=mock_connection, + play_context=play_context, + loader=None, + templar=None, + shared_loader_obj=None, + ) + + # fake a lot of methods as we test those elsewhere + action_base._configure_module = MagicMock() + action_base._supports_check_mode = MagicMock() + action_base._is_pipelining_enabled = MagicMock() + action_base._make_tmp_path = MagicMock() + action_base._transfer_data = MagicMock() + action_base._compute_environment_string = MagicMock() + action_base._low_level_execute_command = MagicMock() + action_base._fixup_perms2 = MagicMock() + + action_base._configure_module.return_value = ('new', '#!/usr/bin/python', 'this is the module data', 'path') + action_base._is_pipelining_enabled.return_value = False + action_base._compute_environment_string.return_value = '' + action_base._connection.has_pipelining = False + action_base._make_tmp_path.return_value = '/the/tmp/path' + action_base._low_level_execute_command.return_value = dict(stdout='{"rc": 0, "stdout": "ok"}') + self.assertEqual(action_base._execute_module(module_name=None, module_args=None), dict(_ansible_parsed=True, rc=0, stdout="ok", stdout_lines=['ok'])) + self.assertEqual( + action_base._execute_module( + module_name='foo', + module_args=dict(z=9, y=8, x=7), + task_vars=dict(a=1) + ), + dict( + _ansible_parsed=True, + rc=0, + stdout="ok", + stdout_lines=['ok'], + ) + ) + + # test with needing/removing a remote tmp path + action_base._configure_module.return_value = ('old', '#!/usr/bin/python', 'this is the module data', 'path') + action_base._is_pipelining_enabled.return_value = False + action_base._make_tmp_path.return_value = '/the/tmp/path' + self.assertEqual(action_base._execute_module(), dict(_ansible_parsed=True, rc=0, stdout="ok", stdout_lines=['ok'])) + + action_base._configure_module.return_value = ('non_native_want_json', '#!/usr/bin/python', 'this is the module data', 'path') + self.assertEqual(action_base._execute_module(), dict(_ansible_parsed=True, rc=0, stdout="ok", stdout_lines=['ok'])) + + play_context.become = True + play_context.become_user = 'foo' + self.assertEqual(action_base._execute_module(), dict(_ansible_parsed=True, rc=0, stdout="ok", stdout_lines=['ok'])) + + # test an invalid shebang return + action_base._configure_module.return_value = ('new', '', 'this is the module data', 'path') + action_base._is_pipelining_enabled.return_value = False + action_base._make_tmp_path.return_value = '/the/tmp/path' + self.assertRaises(AnsibleError, action_base._execute_module) + + # test with check mode enabled, once with support for check + # mode and once with support disabled to raise an error + play_context.check_mode = True + action_base._configure_module.return_value = ('new', '#!/usr/bin/python', 'this is the module data', 'path') + self.assertEqual(action_base._execute_module(), dict(_ansible_parsed=True, rc=0, stdout="ok", stdout_lines=['ok'])) + action_base._supports_check_mode = False + self.assertRaises(AnsibleError, action_base._execute_module) + + def test_action_base_sudo_only_if_user_differs(self): + fake_loader = MagicMock() + fake_loader.get_basedir.return_value = os.getcwd() + play_context = PlayContext() + + action_base = DerivedActionBase(None, None, play_context, fake_loader, None, None) + action_base.get_become_option = MagicMock(return_value='root') + action_base._get_remote_user = MagicMock(return_value='root') + + action_base._connection = MagicMock(exec_command=MagicMock(return_value=(0, '', ''))) + + action_base._connection._shell = shell = MagicMock(append_command=MagicMock(return_value=('JOINED CMD'))) + + action_base._connection.become = become = MagicMock() + become.build_become_command.return_value = 'foo' + + action_base._low_level_execute_command('ECHO', sudoable=True) + become.build_become_command.assert_not_called() + + action_base._get_remote_user.return_value = 'apo' + action_base._low_level_execute_command('ECHO', sudoable=True, executable='/bin/csh') + become.build_become_command.assert_called_once_with("ECHO", shell) + + become.build_become_command.reset_mock() + + with patch.object(C, 'BECOME_ALLOW_SAME_USER', new=True): + action_base._get_remote_user.return_value = 'root' + action_base._low_level_execute_command('ECHO SAME', sudoable=True) + become.build_become_command.assert_called_once_with("ECHO SAME", shell) + + def test__remote_expand_user_relative_pathing(self): + action_base = _action_base() + action_base._play_context.remote_addr = 'bar' + action_base._low_level_execute_command = MagicMock(return_value={'stdout': b'../home/user'}) + action_base._connection._shell.join_path.return_value = '../home/user/foo' + with self.assertRaises(AnsibleError) as cm: + action_base._remote_expand_user('~/foo') + self.assertEqual( + cm.exception.message, + "'bar' returned an invalid relative home directory path containing '..'" + ) + + +class TestActionBaseCleanReturnedData(unittest.TestCase): + def test(self): + + fake_loader = DictDataLoader({ + }) + mock_module_loader = MagicMock() + mock_shared_loader_obj = MagicMock() + mock_shared_loader_obj.module_loader = mock_module_loader + connection_loader_paths = ['/tmp/asdfadf', '/usr/lib64/whatever', + 'dfadfasf', + 'foo.py', + '.*', + # FIXME: a path with parans breaks the regex + # '(.*)', + '/path/to/ansible/lib/ansible/plugins/connection/custom_connection.py', + '/path/to/ansible/lib/ansible/plugins/connection/ssh.py'] + + def fake_all(path_only=None): + for path in connection_loader_paths: + yield path + + mock_connection_loader = MagicMock() + mock_connection_loader.all = fake_all + + mock_shared_loader_obj.connection_loader = mock_connection_loader + mock_connection = MagicMock() + # mock_connection._shell.env_prefix.side_effect = env_prefix + + # action_base = DerivedActionBase(mock_task, mock_connection, play_context, None, None, None) + action_base = DerivedActionBase(task=None, + connection=mock_connection, + play_context=None, + loader=fake_loader, + templar=None, + shared_loader_obj=mock_shared_loader_obj) + data = {'ansible_playbook_python': '/usr/bin/python', + # 'ansible_rsync_path': '/usr/bin/rsync', + 'ansible_python_interpreter': '/usr/bin/python', + 'ansible_ssh_some_var': 'whatever', + 'ansible_ssh_host_key_somehost': 'some key here', + 'some_other_var': 'foo bar'} + data = clean_facts(data) + self.assertNotIn('ansible_playbook_python', data) + self.assertNotIn('ansible_python_interpreter', data) + self.assertIn('ansible_ssh_host_key_somehost', data) + self.assertIn('some_other_var', data) + + +class TestActionBaseParseReturnedData(unittest.TestCase): + + def test_fail_no_json(self): + action_base = _action_base() + rc = 0 + stdout = 'foo\nbar\n' + err = 'oopsy' + returned_data = {'rc': rc, + 'stdout': stdout, + 'stdout_lines': stdout.splitlines(), + 'stderr': err} + res = action_base._parse_returned_data(returned_data) + self.assertFalse(res['_ansible_parsed']) + self.assertTrue(res['failed']) + self.assertEqual(res['module_stderr'], err) + + def test_json_empty(self): + action_base = _action_base() + rc = 0 + stdout = '{}\n' + err = '' + returned_data = {'rc': rc, + 'stdout': stdout, + 'stdout_lines': stdout.splitlines(), + 'stderr': err} + res = action_base._parse_returned_data(returned_data) + del res['_ansible_parsed'] # we always have _ansible_parsed + self.assertEqual(len(res), 0) + self.assertFalse(res) + + def test_json_facts(self): + action_base = _action_base() + rc = 0 + stdout = '{"ansible_facts": {"foo": "bar", "ansible_blip": "blip_value"}}\n' + err = '' + + returned_data = {'rc': rc, + 'stdout': stdout, + 'stdout_lines': stdout.splitlines(), + 'stderr': err} + res = action_base._parse_returned_data(returned_data) + self.assertTrue(res['ansible_facts']) + self.assertIn('ansible_blip', res['ansible_facts']) + # TODO: Should this be an AnsibleUnsafe? + # self.assertIsInstance(res['ansible_facts'], AnsibleUnsafe) + + def test_json_facts_add_host(self): + action_base = _action_base() + rc = 0 + stdout = '''{"ansible_facts": {"foo": "bar", "ansible_blip": "blip_value"}, + "add_host": {"host_vars": {"some_key": ["whatever the add_host object is"]} + } + }\n''' + err = '' + + returned_data = {'rc': rc, + 'stdout': stdout, + 'stdout_lines': stdout.splitlines(), + 'stderr': err} + res = action_base._parse_returned_data(returned_data) + self.assertTrue(res['ansible_facts']) + self.assertIn('ansible_blip', res['ansible_facts']) + self.assertIn('add_host', res) + # TODO: Should this be an AnsibleUnsafe? + # self.assertIsInstance(res['ansible_facts'], AnsibleUnsafe) diff --git a/test/units/plugins/action/test_gather_facts.py b/test/units/plugins/action/test_gather_facts.py new file mode 100644 index 00000000..e15edd39 --- /dev/null +++ b/test/units/plugins/action/test_gather_facts.py @@ -0,0 +1,87 @@ +# (c) 2016, Saran Ahluwalia <ahlusar.ahluwalia@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from units.compat import unittest +from units.compat.mock import MagicMock, patch + +from ansible import constants as C +from ansible.plugins.action.gather_facts import ActionModule +from ansible.playbook.task import Task +from ansible.template import Templar +import ansible.executor.module_common as module_common + +from units.mock.loader import DictDataLoader + + +class TestNetworkFacts(unittest.TestCase): + task = MagicMock(Task) + play_context = MagicMock() + play_context.check_mode = False + connection = MagicMock() + fake_loader = DictDataLoader({ + }) + templar = Templar(loader=fake_loader) + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_network_gather_facts(self): + self.task_vars = {'ansible_network_os': 'ios'} + self.task.action = 'gather_facts' + self.task.async_val = False + self.task._ansible_internal_redirect_list = [] + self.task.args = {'gather_subset': 'min'} + self.task.module_defaults = [{'ios_facts': {'gather_subset': 'min'}}] + + plugin = ActionModule(self.task, self.connection, self.play_context, loader=None, templar=self.templar, shared_loader_obj=None) + plugin._execute_module = MagicMock() + + res = plugin.run(task_vars=self.task_vars) + self.assertEqual(res['ansible_facts']['_ansible_facts_gathered'], True) + + mod_args = plugin._get_module_args('ios_facts', task_vars=self.task_vars) + self.assertEqual(mod_args['gather_subset'], 'min') + + facts_modules = C.config.get_config_value('FACTS_MODULES', variables=self.task_vars) + self.assertEqual(facts_modules, ['ansible.legacy.ios_facts']) + + @patch.object(module_common, '_get_collection_metadata', return_value={}) + def test_network_gather_facts_fqcn(self, mock_collection_metadata): + self.fqcn_task_vars = {'ansible_network_os': 'cisco.ios.ios'} + self.task.action = 'gather_facts' + self.task._ansible_internal_redirect_list = ['cisco.ios.ios_facts'] + self.task.async_val = False + self.task.args = {'gather_subset': 'min'} + self.task.module_defaults = [{'cisco.ios.ios_facts': {'gather_subset': 'min'}}] + + plugin = ActionModule(self.task, self.connection, self.play_context, loader=None, templar=self.templar, shared_loader_obj=None) + plugin._execute_module = MagicMock() + + res = plugin.run(task_vars=self.fqcn_task_vars) + self.assertEqual(res['ansible_facts']['_ansible_facts_gathered'], True) + + mod_args = plugin._get_module_args('cisco.ios.ios_facts', task_vars=self.fqcn_task_vars) + self.assertEqual(mod_args['gather_subset'], 'min') + + facts_modules = C.config.get_config_value('FACTS_MODULES', variables=self.fqcn_task_vars) + self.assertEqual(facts_modules, ['cisco.ios.ios_facts']) diff --git a/test/units/plugins/action/test_raw.py b/test/units/plugins/action/test_raw.py new file mode 100644 index 00000000..a8bde6c1 --- /dev/null +++ b/test/units/plugins/action/test_raw.py @@ -0,0 +1,105 @@ +# (c) 2016, Saran Ahluwalia <ahlusar.ahluwalia@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os + +from ansible.errors import AnsibleActionFail +from units.compat import unittest +from units.compat.mock import MagicMock, Mock +from ansible.plugins.action.raw import ActionModule +from ansible.playbook.task import Task +from ansible.plugins.loader import connection_loader + + +class TestCopyResultExclude(unittest.TestCase): + + def setUp(self): + self.play_context = Mock() + self.play_context.shell = 'sh' + self.connection = connection_loader.get('local', self.play_context, os.devnull) + + def tearDown(self): + pass + + # The current behavior of the raw aciton in regards to executable is currently in question; + # the test_raw_executable_is_not_empty_string verifies the current behavior (whether it is desireed or not. + # Please refer to the following for context: + # Issue: https://github.com/ansible/ansible/issues/16054 + # PR: https://github.com/ansible/ansible/pull/16085 + + def test_raw_executable_is_not_empty_string(self): + + task = MagicMock(Task) + task.async_val = False + + task.args = {'_raw_params': 'Args1'} + self.play_context.check_mode = False + + self.mock_am = ActionModule(task, self.connection, self.play_context, loader=None, templar=None, shared_loader_obj=None) + self.mock_am._low_level_execute_command = Mock(return_value={}) + self.mock_am.display = Mock() + self.mock_am._admin_users = ['root', 'toor'] + + self.mock_am.run() + self.mock_am._low_level_execute_command.assert_called_with('Args1', executable=False) + + def test_raw_check_mode_is_True(self): + + task = MagicMock(Task) + task.async_val = False + + task.args = {'_raw_params': 'Args1'} + self.play_context.check_mode = True + + try: + self.mock_am = ActionModule(task, self.connection, self.play_context, loader=None, templar=None, shared_loader_obj=None) + except AnsibleActionFail: + pass + + def test_raw_test_environment_is_None(self): + + task = MagicMock(Task) + task.async_val = False + + task.args = {'_raw_params': 'Args1'} + task.environment = None + self.play_context.check_mode = False + + self.mock_am = ActionModule(task, self.connection, self.play_context, loader=None, templar=None, shared_loader_obj=None) + self.mock_am._low_level_execute_command = Mock(return_value={}) + self.mock_am.display = Mock() + + self.assertEqual(task.environment, None) + + def test_raw_task_vars_is_not_None(self): + + task = MagicMock(Task) + task.async_val = False + + task.args = {'_raw_params': 'Args1'} + task.environment = None + self.play_context.check_mode = False + + self.mock_am = ActionModule(task, self.connection, self.play_context, loader=None, templar=None, shared_loader_obj=None) + self.mock_am._low_level_execute_command = Mock(return_value={}) + self.mock_am.display = Mock() + + self.mock_am.run(task_vars={'a': 'b'}) + self.assertEqual(task.environment, None) diff --git a/test/units/plugins/become/__init__.py b/test/units/plugins/become/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/units/plugins/become/__init__.py diff --git a/test/units/plugins/become/conftest.py b/test/units/plugins/become/conftest.py new file mode 100644 index 00000000..a04a5e2d --- /dev/null +++ b/test/units/plugins/become/conftest.py @@ -0,0 +1,37 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# (c) 2017 Ansible Project +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import pytest + +from ansible.cli.arguments import option_helpers as opt_help +from ansible.utils import context_objects as co + + +@pytest.fixture +def parser(): + parser = opt_help.create_base_parser('testparser') + + opt_help.add_runas_options(parser) + opt_help.add_meta_options(parser) + opt_help.add_runtask_options(parser) + opt_help.add_vault_options(parser) + opt_help.add_async_options(parser) + opt_help.add_connect_options(parser) + opt_help.add_subset_options(parser) + opt_help.add_check_options(parser) + opt_help.add_inventory_options(parser) + + return parser + + +@pytest.fixture +def reset_cli_args(): + co.GlobalCLIArgs._Singleton__instance = None + yield + co.GlobalCLIArgs._Singleton__instance = None diff --git a/test/units/plugins/become/test_su.py b/test/units/plugins/become/test_su.py new file mode 100644 index 00000000..73eb71dd --- /dev/null +++ b/test/units/plugins/become/test_su.py @@ -0,0 +1,40 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# (c) 2020 Ansible Project +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re + +from ansible import context +from ansible.playbook.play_context import PlayContext +from ansible.plugins.loader import become_loader + + +def test_su(mocker, parser, reset_cli_args): + options = parser.parse_args([]) + context._init_global_context(options) + play_context = PlayContext() + + default_cmd = "/bin/foo" + default_exe = "/bin/bash" + su_exe = 'su' + su_flags = '' + + cmd = play_context.make_become_cmd(cmd=default_cmd, executable=default_exe) + assert cmd == default_cmd + + success = 'BECOME-SUCCESS-.+?' + + play_context.become = True + play_context.become_user = 'foo' + play_context.become_pass = None + play_context.become_method = 'su' + play_context.set_become_plugin(become_loader.get('su')) + play_context.become_flags = su_flags + cmd = play_context.make_become_cmd(cmd=default_cmd, executable=default_exe) + assert (re.match("""%s %s -c '%s -c '"'"'echo %s; %s'"'"''""" % (su_exe, play_context.become_user, default_exe, + success, default_cmd), cmd) is not None) diff --git a/test/units/plugins/become/test_sudo.py b/test/units/plugins/become/test_sudo.py new file mode 100644 index 00000000..ba501296 --- /dev/null +++ b/test/units/plugins/become/test_sudo.py @@ -0,0 +1,45 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# (c) 2020 Ansible Project +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re + +from ansible import context +from ansible.playbook.play_context import PlayContext +from ansible.plugins.loader import become_loader + + +def test_sudo(mocker, parser, reset_cli_args): + options = parser.parse_args([]) + context._init_global_context(options) + play_context = PlayContext() + + default_cmd = "/bin/foo" + default_exe = "/bin/bash" + sudo_exe = 'sudo' + sudo_flags = '-H -s -n' + + cmd = play_context.make_become_cmd(cmd=default_cmd, executable=default_exe) + assert cmd == default_cmd + + success = 'BECOME-SUCCESS-.+?' + + play_context.become = True + play_context.become_user = 'foo' + play_context.set_become_plugin(become_loader.get('sudo')) + play_context.become_flags = sudo_flags + cmd = play_context.make_become_cmd(cmd=default_cmd, executable=default_exe) + + assert (re.match("""%s %s -u %s %s -c 'echo %s; %s'""" % (sudo_exe, sudo_flags, play_context.become_user, + default_exe, success, default_cmd), cmd) is not None) + + play_context.become_pass = 'testpass' + cmd = play_context.make_become_cmd(cmd=default_cmd, executable=default_exe) + assert (re.match("""%s %s -p "%s" -u %s %s -c 'echo %s; %s'""" % (sudo_exe, sudo_flags.replace('-n', ''), + r"\[sudo via ansible, key=.+?\] password:", play_context.become_user, + default_exe, success, default_cmd), cmd) is not None) diff --git a/test/units/plugins/cache/__init__.py b/test/units/plugins/cache/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/units/plugins/cache/__init__.py diff --git a/test/units/plugins/cache/test_cache.py b/test/units/plugins/cache/test_cache.py new file mode 100644 index 00000000..1f16b806 --- /dev/null +++ b/test/units/plugins/cache/test_cache.py @@ -0,0 +1,167 @@ +# (c) 2012-2015, Michael DeHaan <michael.dehaan@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from units.compat import unittest, mock +from ansible.errors import AnsibleError +from ansible.plugins.cache import FactCache, CachePluginAdjudicator +from ansible.plugins.cache.base import BaseCacheModule +from ansible.plugins.cache.memory import CacheModule as MemoryCache +from ansible.plugins.loader import cache_loader + +import pytest + + +class TestCachePluginAdjudicator: + # memory plugin cache + cache = CachePluginAdjudicator() + cache['cache_key'] = {'key1': 'value1', 'key2': 'value2'} + cache['cache_key_2'] = {'key': 'value'} + + def test___setitem__(self): + self.cache['new_cache_key'] = {'new_key1': ['new_value1', 'new_value2']} + assert self.cache['new_cache_key'] == {'new_key1': ['new_value1', 'new_value2']} + + def test_inner___setitem__(self): + self.cache['new_cache_key'] = {'new_key1': ['new_value1', 'new_value2']} + self.cache['new_cache_key']['new_key1'][0] = 'updated_value1' + assert self.cache['new_cache_key'] == {'new_key1': ['updated_value1', 'new_value2']} + + def test___contains__(self): + assert 'cache_key' in self.cache + assert 'not_cache_key' not in self.cache + + def test_get(self): + assert self.cache.get('cache_key') == {'key1': 'value1', 'key2': 'value2'} + + def test_get_with_default(self): + assert self.cache.get('foo', 'bar') == 'bar' + + def test_get_without_default(self): + assert self.cache.get('foo') is None + + def test___getitem__(self): + with pytest.raises(KeyError) as err: + self.cache['foo'] + + def test_pop_with_default(self): + assert self.cache.pop('foo', 'bar') == 'bar' + + def test_pop_without_default(self): + with pytest.raises(KeyError) as err: + assert self.cache.pop('foo') + + def test_pop(self): + v = self.cache.pop('cache_key_2') + assert v == {'key': 'value'} + assert 'cache_key_2' not in self.cache + + def test_update(self): + self.cache.update({'cache_key': {'key2': 'updatedvalue'}}) + assert self.cache['cache_key']['key2'] == 'updatedvalue' + + +class TestFactCache(unittest.TestCase): + + def setUp(self): + with mock.patch('ansible.constants.CACHE_PLUGIN', 'memory'): + self.cache = FactCache() + + def test_copy(self): + self.cache['avocado'] = 'fruit' + self.cache['daisy'] = 'flower' + a_copy = self.cache.copy() + self.assertEqual(type(a_copy), dict) + self.assertEqual(a_copy, dict(avocado='fruit', daisy='flower')) + + def test_plugin_load_failure(self): + # See https://github.com/ansible/ansible/issues/18751 + # Note no fact_connection config set, so this will fail + with mock.patch('ansible.constants.CACHE_PLUGIN', 'json'): + self.assertRaisesRegexp(AnsibleError, + "Unable to load the facts cache plugin.*json.*", + FactCache) + + def test_update(self): + self.cache.update({'cache_key': {'key2': 'updatedvalue'}}) + assert self.cache['cache_key']['key2'] == 'updatedvalue' + + def test_update_legacy(self): + self.cache.update('cache_key', {'key2': 'updatedvalue'}) + assert self.cache['cache_key']['key2'] == 'updatedvalue' + + def test_update_legacy_key_exists(self): + self.cache['cache_key'] = {'key': 'value', 'key2': 'value2'} + self.cache.update('cache_key', {'key': 'updatedvalue'}) + assert self.cache['cache_key']['key'] == 'updatedvalue' + assert self.cache['cache_key']['key2'] == 'value2' + + +class TestAbstractClass(unittest.TestCase): + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_subclass_error(self): + class CacheModule1(BaseCacheModule): + pass + with self.assertRaises(TypeError): + CacheModule1() # pylint: disable=abstract-class-instantiated + + class CacheModule2(BaseCacheModule): + def get(self, key): + super(CacheModule2, self).get(key) + + with self.assertRaises(TypeError): + CacheModule2() # pylint: disable=abstract-class-instantiated + + def test_subclass_success(self): + class CacheModule3(BaseCacheModule): + def get(self, key): + super(CacheModule3, self).get(key) + + def set(self, key, value): + super(CacheModule3, self).set(key, value) + + def keys(self): + super(CacheModule3, self).keys() + + def contains(self, key): + super(CacheModule3, self).contains(key) + + def delete(self, key): + super(CacheModule3, self).delete(key) + + def flush(self): + super(CacheModule3, self).flush() + + def copy(self): + super(CacheModule3, self).copy() + + self.assertIsInstance(CacheModule3(), CacheModule3) + + def test_memory_cachemodule(self): + self.assertIsInstance(MemoryCache(), MemoryCache) + + def test_memory_cachemodule_with_loader(self): + self.assertIsInstance(cache_loader.get('memory'), MemoryCache) diff --git a/test/units/plugins/callback/__init__.py b/test/units/plugins/callback/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/units/plugins/callback/__init__.py diff --git a/test/units/plugins/callback/test_callback.py b/test/units/plugins/callback/test_callback.py new file mode 100644 index 00000000..0c9a335c --- /dev/null +++ b/test/units/plugins/callback/test_callback.py @@ -0,0 +1,412 @@ +# (c) 2012-2014, Chris Meyers <chris.meyers.fsu@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import json +import re +import textwrap +import types + +from units.compat import unittest +from units.compat.mock import MagicMock + +from ansible.plugins.callback import CallbackBase + + +class TestCallback(unittest.TestCase): + # FIXME: This doesn't really test anything... + def test_init(self): + CallbackBase() + + def test_display(self): + display_mock = MagicMock() + display_mock.verbosity = 0 + cb = CallbackBase(display=display_mock) + self.assertIs(cb._display, display_mock) + + def test_display_verbose(self): + display_mock = MagicMock() + display_mock.verbosity = 5 + cb = CallbackBase(display=display_mock) + self.assertIs(cb._display, display_mock) + + # TODO: import callback module so we can patch callback.cli/callback.C + + +class TestCallbackResults(unittest.TestCase): + + def test_get_item(self): + cb = CallbackBase() + results = {'item': 'some_item'} + res = cb._get_item(results) + self.assertEqual(res, 'some_item') + + def test_get_item_no_log(self): + cb = CallbackBase() + results = {'item': 'some_item', '_ansible_no_log': True} + res = cb._get_item(results) + self.assertEqual(res, "(censored due to no_log)") + + results = {'item': 'some_item', '_ansible_no_log': False} + res = cb._get_item(results) + self.assertEqual(res, "some_item") + + def test_get_item_label(self): + cb = CallbackBase() + results = {'item': 'some_item'} + res = cb._get_item_label(results) + self.assertEqual(res, 'some_item') + + def test_get_item_label_no_log(self): + cb = CallbackBase() + results = {'item': 'some_item', '_ansible_no_log': True} + res = cb._get_item_label(results) + self.assertEqual(res, "(censored due to no_log)") + + results = {'item': 'some_item', '_ansible_no_log': False} + res = cb._get_item_label(results) + self.assertEqual(res, "some_item") + + def test_clean_results_debug_task(self): + cb = CallbackBase() + result = {'item': 'some_item', + 'invocation': 'foo --bar whatever [some_json]', + 'a': 'a single a in result note letter a is in invocation', + 'b': 'a single b in result note letter b is not in invocation', + 'changed': True} + + cb._clean_results(result, 'debug') + + # See https://github.com/ansible/ansible/issues/33723 + self.assertTrue('a' in result) + self.assertTrue('b' in result) + self.assertFalse('invocation' in result) + self.assertFalse('changed' in result) + + def test_clean_results_debug_task_no_invocation(self): + cb = CallbackBase() + result = {'item': 'some_item', + 'a': 'a single a in result note letter a is in invocation', + 'b': 'a single b in result note letter b is not in invocation', + 'changed': True} + + cb._clean_results(result, 'debug') + self.assertTrue('a' in result) + self.assertTrue('b' in result) + self.assertFalse('changed' in result) + self.assertFalse('invocation' in result) + + def test_clean_results_debug_task_empty_results(self): + cb = CallbackBase() + result = {} + cb._clean_results(result, 'debug') + self.assertFalse('invocation' in result) + self.assertEqual(len(result), 0) + + def test_clean_results(self): + cb = CallbackBase() + result = {'item': 'some_item', + 'invocation': 'foo --bar whatever [some_json]', + 'a': 'a single a in result note letter a is in invocation', + 'b': 'a single b in result note letter b is not in invocation', + 'changed': True} + + expected_result = result.copy() + cb._clean_results(result, 'ebug') + self.assertEqual(result, expected_result) + + +class TestCallbackDumpResults(object): + def test_internal_keys(self): + cb = CallbackBase() + result = {'item': 'some_item', + '_ansible_some_var': 'SENTINEL', + 'testing_ansible_out': 'should_be_left_in LEFTIN', + 'invocation': 'foo --bar whatever [some_json]', + 'some_dict_key': {'a_sub_dict_for_key': 'baz'}, + 'bad_dict_key': {'_ansible_internal_blah': 'SENTINEL'}, + 'changed': True} + json_out = cb._dump_results(result) + assert '"_ansible_' not in json_out + assert 'SENTINEL' not in json_out + assert 'LEFTIN' in json_out + + def test_exception(self): + cb = CallbackBase() + result = {'item': 'some_item LEFTIN', + 'exception': ['frame1', 'SENTINEL']} + json_out = cb._dump_results(result) + assert 'SENTINEL' not in json_out + assert 'exception' not in json_out + assert 'LEFTIN' in json_out + + def test_verbose(self): + cb = CallbackBase() + result = {'item': 'some_item LEFTIN', + '_ansible_verbose_always': 'chicane'} + json_out = cb._dump_results(result) + assert 'SENTINEL' not in json_out + assert 'LEFTIN' in json_out + + def test_diff(self): + cb = CallbackBase() + result = {'item': 'some_item LEFTIN', + 'diff': ['remove stuff', 'added LEFTIN'], + '_ansible_verbose_always': 'chicane'} + json_out = cb._dump_results(result) + assert 'SENTINEL' not in json_out + assert 'LEFTIN' in json_out + + def test_mixed_keys(self): + cb = CallbackBase() + result = {3: 'pi', + 'tau': 6} + json_out = cb._dump_results(result) + round_trip_result = json.loads(json_out) + assert len(round_trip_result) == 2 + assert '3' in round_trip_result + assert 'tau' in round_trip_result + assert round_trip_result['3'] == 'pi' + assert round_trip_result['tau'] == 6 + + +class TestCallbackDiff(unittest.TestCase): + + def setUp(self): + self.cb = CallbackBase() + + def _strip_color(self, s): + return re.sub('\033\\[[^m]*m', '', s) + + def test_difflist(self): + # TODO: split into smaller tests? + difflist = [{'before': u'preface\nThe Before String\npostscript', + 'after': u'preface\nThe After String\npostscript', + 'before_header': u'just before', + 'after_header': u'just after' + }, + {'before': u'preface\nThe Before String\npostscript', + 'after': u'preface\nThe After String\npostscript', + }, + {'src_binary': 'chicane'}, + {'dst_binary': 'chicanery'}, + {'dst_larger': 1}, + {'src_larger': 2}, + {'prepared': u'what does prepared do?'}, + {'before_header': u'just before'}, + {'after_header': u'just after'}] + + res = self.cb._get_diff(difflist) + + self.assertIn(u'Before String', res) + self.assertIn(u'After String', res) + self.assertIn(u'just before', res) + self.assertIn(u'just after', res) + + def test_simple_diff(self): + self.assertMultiLineEqual( + self._strip_color(self.cb._get_diff({ + 'before_header': 'somefile.txt', + 'after_header': 'generated from template somefile.j2', + 'before': 'one\ntwo\nthree\n', + 'after': 'one\nthree\nfour\n', + })), + textwrap.dedent('''\ + --- before: somefile.txt + +++ after: generated from template somefile.j2 + @@ -1,3 +1,3 @@ + one + -two + three + +four + + ''')) + + def test_new_file(self): + self.assertMultiLineEqual( + self._strip_color(self.cb._get_diff({ + 'before_header': 'somefile.txt', + 'after_header': 'generated from template somefile.j2', + 'before': '', + 'after': 'one\ntwo\nthree\n', + })), + textwrap.dedent('''\ + --- before: somefile.txt + +++ after: generated from template somefile.j2 + @@ -0,0 +1,3 @@ + +one + +two + +three + + ''')) + + def test_clear_file(self): + self.assertMultiLineEqual( + self._strip_color(self.cb._get_diff({ + 'before_header': 'somefile.txt', + 'after_header': 'generated from template somefile.j2', + 'before': 'one\ntwo\nthree\n', + 'after': '', + })), + textwrap.dedent('''\ + --- before: somefile.txt + +++ after: generated from template somefile.j2 + @@ -1,3 +0,0 @@ + -one + -two + -three + + ''')) + + def test_no_trailing_newline_before(self): + self.assertMultiLineEqual( + self._strip_color(self.cb._get_diff({ + 'before_header': 'somefile.txt', + 'after_header': 'generated from template somefile.j2', + 'before': 'one\ntwo\nthree', + 'after': 'one\ntwo\nthree\n', + })), + textwrap.dedent('''\ + --- before: somefile.txt + +++ after: generated from template somefile.j2 + @@ -1,3 +1,3 @@ + one + two + -three + \\ No newline at end of file + +three + + ''')) + + def test_no_trailing_newline_after(self): + self.assertMultiLineEqual( + self._strip_color(self.cb._get_diff({ + 'before_header': 'somefile.txt', + 'after_header': 'generated from template somefile.j2', + 'before': 'one\ntwo\nthree\n', + 'after': 'one\ntwo\nthree', + })), + textwrap.dedent('''\ + --- before: somefile.txt + +++ after: generated from template somefile.j2 + @@ -1,3 +1,3 @@ + one + two + -three + +three + \\ No newline at end of file + + ''')) + + def test_no_trailing_newline_both(self): + self.assertMultiLineEqual( + self.cb._get_diff({ + 'before_header': 'somefile.txt', + 'after_header': 'generated from template somefile.j2', + 'before': 'one\ntwo\nthree', + 'after': 'one\ntwo\nthree', + }), + '') + + def test_no_trailing_newline_both_with_some_changes(self): + self.assertMultiLineEqual( + self._strip_color(self.cb._get_diff({ + 'before_header': 'somefile.txt', + 'after_header': 'generated from template somefile.j2', + 'before': 'one\ntwo\nthree', + 'after': 'one\nfive\nthree', + })), + textwrap.dedent('''\ + --- before: somefile.txt + +++ after: generated from template somefile.j2 + @@ -1,3 +1,3 @@ + one + -two + +five + three + \\ No newline at end of file + + ''')) + + def test_diff_dicts(self): + self.assertMultiLineEqual( + self._strip_color(self.cb._get_diff({ + 'before': dict(one=1, two=2, three=3), + 'after': dict(one=1, three=3, four=4), + })), + textwrap.dedent('''\ + --- before + +++ after + @@ -1,5 +1,5 @@ + { + + "four": 4, + "one": 1, + - "three": 3, + - "two": 2 + + "three": 3 + } + + ''')) + + def test_diff_before_none(self): + self.assertMultiLineEqual( + self._strip_color(self.cb._get_diff({ + 'before': None, + 'after': 'one line\n', + })), + textwrap.dedent('''\ + --- before + +++ after + @@ -0,0 +1 @@ + +one line + + ''')) + + def test_diff_after_none(self): + self.assertMultiLineEqual( + self._strip_color(self.cb._get_diff({ + 'before': 'one line\n', + 'after': None, + })), + textwrap.dedent('''\ + --- before + +++ after + @@ -1 +0,0 @@ + -one line + + ''')) + + +class TestCallbackOnMethods(unittest.TestCase): + def _find_on_methods(self, callback): + cb_dir = dir(callback) + method_names = [x for x in cb_dir if '_on_' in x] + methods = [getattr(callback, mn) for mn in method_names] + return methods + + def test_are_methods(self): + cb = CallbackBase() + for method in self._find_on_methods(cb): + self.assertIsInstance(method, types.MethodType) + + def test_on_any(self): + cb = CallbackBase() + cb.v2_on_any('whatever', some_keyword='blippy') + cb.on_any('whatever', some_keyword='blippy') diff --git a/test/units/plugins/connection/__init__.py b/test/units/plugins/connection/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/units/plugins/connection/__init__.py diff --git a/test/units/plugins/connection/test_connection.py b/test/units/plugins/connection/test_connection.py new file mode 100644 index 00000000..17c2e085 --- /dev/null +++ b/test/units/plugins/connection/test_connection.py @@ -0,0 +1,169 @@ +# (c) 2015, Toshio Kuratomi <tkuratomi@ansible.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from io import StringIO +import sys +import pytest + +from units.compat import mock +from units.compat import unittest +from units.compat.mock import MagicMock +from units.compat.mock import patch +from ansible.errors import AnsibleError +from ansible.playbook.play_context import PlayContext +from ansible.plugins.connection import ConnectionBase +from ansible.plugins.loader import become_loader + + +class TestConnectionBaseClass(unittest.TestCase): + + def setUp(self): + self.play_context = PlayContext() + self.play_context.prompt = ( + '[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: ' + ) + self.in_stream = StringIO() + + def tearDown(self): + pass + + def test_subclass_error(self): + class ConnectionModule1(ConnectionBase): + pass + with self.assertRaises(TypeError): + ConnectionModule1() # pylint: disable=abstract-class-instantiated + + class ConnectionModule2(ConnectionBase): + def get(self, key): + super(ConnectionModule2, self).get(key) + + with self.assertRaises(TypeError): + ConnectionModule2() # pylint: disable=abstract-class-instantiated + + def test_subclass_success(self): + class ConnectionModule3(ConnectionBase): + + @property + def transport(self): + pass + + def _connect(self): + pass + + def exec_command(self): + pass + + def put_file(self): + pass + + def fetch_file(self): + pass + + def close(self): + pass + + self.assertIsInstance(ConnectionModule3(self.play_context, self.in_stream), ConnectionModule3) + + def test_check_password_prompt(self): + local = ( + b'[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: \n' + b'BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq\n' + ) + + ssh_pipelining_vvvv = b''' +debug3: mux_master_read_cb: channel 1 packet type 0x10000002 len 251 +debug2: process_mux_new_session: channel 1: request tty 0, X 1, agent 1, subsys 0, term "xterm-256color", cmd "/bin/sh -c 'sudo -H -S -p "[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: " -u root /bin/sh -c '"'"'echo BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq; /bin/true'"'"' && sleep 0'", env 0 +debug3: process_mux_new_session: got fds stdin 9, stdout 10, stderr 11 +debug2: client_session2_setup: id 2 +debug1: Sending command: /bin/sh -c 'sudo -H -S -p "[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: " -u root /bin/sh -c '"'"'echo BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq; /bin/true'"'"' && sleep 0' +debug2: channel 2: request exec confirm 1 +debug2: channel 2: rcvd ext data 67 +[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: debug2: channel 2: written 67 to efd 11 +BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq +debug3: receive packet: type 98 +''' # noqa + + ssh_nopipelining_vvvv = b''' +debug3: mux_master_read_cb: channel 1 packet type 0x10000002 len 251 +debug2: process_mux_new_session: channel 1: request tty 1, X 1, agent 1, subsys 0, term "xterm-256color", cmd "/bin/sh -c 'sudo -H -S -p "[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: " -u root /bin/sh -c '"'"'echo BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq; /bin/true'"'"' && sleep 0'", env 0 +debug3: mux_client_request_session: session request sent +debug3: send packet: type 98 +debug1: Sending command: /bin/sh -c 'sudo -H -S -p "[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: " -u root /bin/sh -c '"'"'echo BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq; /bin/true'"'"' && sleep 0' +debug2: channel 2: request exec confirm 1 +debug2: exec request accepted on channel 2 +[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: debug3: receive packet: type 2 +debug3: Received SSH2_MSG_IGNORE +debug3: Received SSH2_MSG_IGNORE + +BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq +debug3: receive packet: type 98 +''' # noqa + + ssh_novvvv = ( + b'[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: \n' + b'BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq\n' + ) + + dns_issue = ( + b'timeout waiting for privilege escalation password prompt:\n' + b'sudo: sudo: unable to resolve host tcloud014\n' + b'[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: \n' + b'BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq\n' + ) + + nothing = b'' + + in_front = b''' +debug1: Sending command: /bin/sh -c 'sudo -H -S -p "[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: " -u root /bin/sh -c '"'"'echo +''' + + class ConnectionFoo(ConnectionBase): + + @property + def transport(self): + pass + + def _connect(self): + pass + + def exec_command(self): + pass + + def put_file(self): + pass + + def fetch_file(self): + pass + + def close(self): + pass + + c = ConnectionFoo(self.play_context, self.in_stream) + c.set_become_plugin(become_loader.get('sudo')) + c.become.prompt = '[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: ' + + self.assertTrue(c.check_password_prompt(local)) + self.assertTrue(c.check_password_prompt(ssh_pipelining_vvvv)) + self.assertTrue(c.check_password_prompt(ssh_nopipelining_vvvv)) + self.assertTrue(c.check_password_prompt(ssh_novvvv)) + self.assertTrue(c.check_password_prompt(dns_issue)) + self.assertFalse(c.check_password_prompt(nothing)) + self.assertFalse(c.check_password_prompt(in_front)) diff --git a/test/units/plugins/connection/test_local.py b/test/units/plugins/connection/test_local.py new file mode 100644 index 00000000..e5525855 --- /dev/null +++ b/test/units/plugins/connection/test_local.py @@ -0,0 +1,40 @@ +# +# (c) 2020 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from io import StringIO +import pytest + +from units.compat import unittest +from ansible.plugins.connection import local +from ansible.playbook.play_context import PlayContext + + +class TestLocalConnectionClass(unittest.TestCase): + + def test_local_connection_module(self): + play_context = PlayContext() + play_context.prompt = ( + '[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: ' + ) + in_stream = StringIO() + + self.assertIsInstance(local.Connection(play_context, in_stream), local.Connection) diff --git a/test/units/plugins/connection/test_paramiko.py b/test/units/plugins/connection/test_paramiko.py new file mode 100644 index 00000000..e3643b14 --- /dev/null +++ b/test/units/plugins/connection/test_paramiko.py @@ -0,0 +1,42 @@ +# +# (c) 2020 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from io import StringIO +import pytest + +from units.compat import unittest +from ansible.plugins.connection import paramiko_ssh +from ansible.playbook.play_context import PlayContext + + +class TestParamikoConnectionClass(unittest.TestCase): + + def test_paramiko_connection_module(self): + play_context = PlayContext() + play_context.prompt = ( + '[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: ' + ) + in_stream = StringIO() + + self.assertIsInstance( + paramiko_ssh.Connection(play_context, in_stream), + paramiko_ssh.Connection) diff --git a/test/units/plugins/connection/test_psrp.py b/test/units/plugins/connection/test_psrp.py new file mode 100644 index 00000000..f6416751 --- /dev/null +++ b/test/units/plugins/connection/test_psrp.py @@ -0,0 +1,233 @@ +# -*- coding: utf-8 -*- +# (c) 2018, Jordan Borean <jborean@redhat.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import pytest +import sys + +from io import StringIO +from units.compat.mock import MagicMock + +from ansible.playbook.play_context import PlayContext +from ansible.plugins.loader import connection_loader +from ansible.utils.display import Display + + +@pytest.fixture(autouse=True) +def psrp_connection(): + """Imports the psrp connection plugin with a mocked pypsrp module for testing""" + + # Take a snapshot of sys.modules before we manipulate it + orig_modules = sys.modules.copy() + try: + fake_pypsrp = MagicMock() + fake_pypsrp.FEATURES = [ + 'wsman_locale', + 'wsman_read_timeout', + 'wsman_reconnections', + ] + + fake_wsman = MagicMock() + fake_wsman.AUTH_KWARGS = { + "certificate": ["certificate_key_pem", "certificate_pem"], + "credssp": ["credssp_auth_mechanism", "credssp_disable_tlsv1_2", + "credssp_minimum_version"], + "negotiate": ["negotiate_delegate", "negotiate_hostname_override", + "negotiate_send_cbt", "negotiate_service"], + "mock": ["mock_test1", "mock_test2"], + } + + sys.modules["pypsrp"] = fake_pypsrp + sys.modules["pypsrp.complex_objects"] = MagicMock() + sys.modules["pypsrp.exceptions"] = MagicMock() + sys.modules["pypsrp.host"] = MagicMock() + sys.modules["pypsrp.powershell"] = MagicMock() + sys.modules["pypsrp.shell"] = MagicMock() + sys.modules["pypsrp.wsman"] = fake_wsman + sys.modules["requests.exceptions"] = MagicMock() + + from ansible.plugins.connection import psrp + + # Take a copy of the original import state vars before we set to an ok import + orig_has_psrp = psrp.HAS_PYPSRP + orig_psrp_imp_err = psrp.PYPSRP_IMP_ERR + + yield psrp + + psrp.HAS_PYPSRP = orig_has_psrp + psrp.PYPSRP_IMP_ERR = orig_psrp_imp_err + finally: + # Restore sys.modules back to our pre-shenanigans + sys.modules = orig_modules + + +class TestConnectionPSRP(object): + + OPTIONS_DATA = ( + # default options + ( + {'_extras': {}}, + { + '_psrp_auth': 'negotiate', + '_psrp_cert_validation': True, + '_psrp_configuration_name': 'Microsoft.PowerShell', + '_psrp_connection_timeout': 30, + '_psrp_message_encryption': 'auto', + '_psrp_host': 'inventory_hostname', + '_psrp_conn_kwargs': { + 'server': 'inventory_hostname', + 'port': 5986, + 'username': None, + 'password': None, + 'ssl': True, + 'path': 'wsman', + 'auth': 'negotiate', + 'cert_validation': True, + 'connection_timeout': 30, + 'encryption': 'auto', + 'proxy': None, + 'no_proxy': False, + 'max_envelope_size': 153600, + 'operation_timeout': 20, + 'certificate_key_pem': None, + 'certificate_pem': None, + 'credssp_auth_mechanism': 'auto', + 'credssp_disable_tlsv1_2': False, + 'credssp_minimum_version': 2, + 'negotiate_delegate': None, + 'negotiate_hostname_override': None, + 'negotiate_send_cbt': True, + 'negotiate_service': 'WSMAN', + 'read_timeout': 30, + 'reconnection_backoff': 2.0, + 'reconnection_retries': 0, + }, + '_psrp_max_envelope_size': 153600, + '_psrp_ignore_proxy': False, + '_psrp_operation_timeout': 20, + '_psrp_pass': None, + '_psrp_path': 'wsman', + '_psrp_port': 5986, + '_psrp_proxy': None, + '_psrp_protocol': 'https', + '_psrp_user': None + }, + ), + # ssl=False when port defined to 5985 + ( + {'_extras': {}, 'ansible_port': '5985'}, + { + '_psrp_port': 5985, + '_psrp_protocol': 'http' + }, + ), + # ssl=True when port defined to not 5985 + ( + {'_extras': {}, 'ansible_port': 1234}, + { + '_psrp_port': 1234, + '_psrp_protocol': 'https' + }, + ), + # port 5986 when ssl=True + ( + {'_extras': {}, 'ansible_psrp_protocol': 'https'}, + { + '_psrp_port': 5986, + '_psrp_protocol': 'https' + }, + ), + # port 5985 when ssl=False + ( + {'_extras': {}, 'ansible_psrp_protocol': 'http'}, + { + '_psrp_port': 5985, + '_psrp_protocol': 'http' + }, + ), + # psrp extras + ( + {'_extras': {'ansible_psrp_mock_test1': True}}, + { + '_psrp_conn_kwargs': { + 'server': 'inventory_hostname', + 'port': 5986, + 'username': None, + 'password': None, + 'ssl': True, + 'path': 'wsman', + 'auth': 'negotiate', + 'cert_validation': True, + 'connection_timeout': 30, + 'encryption': 'auto', + 'proxy': None, + 'no_proxy': False, + 'max_envelope_size': 153600, + 'operation_timeout': 20, + 'certificate_key_pem': None, + 'certificate_pem': None, + 'credssp_auth_mechanism': 'auto', + 'credssp_disable_tlsv1_2': False, + 'credssp_minimum_version': 2, + 'negotiate_delegate': None, + 'negotiate_hostname_override': None, + 'negotiate_send_cbt': True, + 'negotiate_service': 'WSMAN', + 'read_timeout': 30, + 'reconnection_backoff': 2.0, + 'reconnection_retries': 0, + 'mock_test1': True + }, + }, + ), + # cert validation through string repr of bool + ( + {'_extras': {}, 'ansible_psrp_cert_validation': 'ignore'}, + { + '_psrp_cert_validation': False + }, + ), + # cert validation path + ( + {'_extras': {}, 'ansible_psrp_cert_trust_path': '/path/cert.pem'}, + { + '_psrp_cert_validation': '/path/cert.pem' + }, + ), + ) + + # pylint bug: https://github.com/PyCQA/pylint/issues/511 + # pylint: disable=undefined-variable + @pytest.mark.parametrize('options, expected', + ((o, e) for o, e in OPTIONS_DATA)) + def test_set_options(self, options, expected): + pc = PlayContext() + new_stdin = StringIO() + + conn = connection_loader.get('psrp', pc, new_stdin) + conn.set_options(var_options=options) + conn._build_kwargs() + + for attr, expected in expected.items(): + actual = getattr(conn, attr) + assert actual == expected, \ + "psrp attr '%s', actual '%s' != expected '%s'"\ + % (attr, actual, expected) + + def test_set_invalid_extras_options(self, monkeypatch): + pc = PlayContext() + new_stdin = StringIO() + + conn = connection_loader.get('psrp', pc, new_stdin) + conn.set_options(var_options={'_extras': {'ansible_psrp_mock_test3': True}}) + + mock_display = MagicMock() + monkeypatch.setattr(Display, "warning", mock_display) + conn._build_kwargs() + + assert mock_display.call_args[0][0] == \ + 'ansible_psrp_mock_test3 is unsupported by the current psrp version installed' diff --git a/test/units/plugins/connection/test_ssh.py b/test/units/plugins/connection/test_ssh.py new file mode 100644 index 00000000..cfe7fcb6 --- /dev/null +++ b/test/units/plugins/connection/test_ssh.py @@ -0,0 +1,688 @@ +# -*- coding: utf-8 -*- +# (c) 2015, Toshio Kuratomi <tkuratomi@ansible.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from io import StringIO +import pytest + + +from ansible import constants as C +from ansible.errors import AnsibleAuthenticationFailure +from units.compat import unittest +from units.compat.mock import patch, MagicMock, PropertyMock +from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound +from ansible.module_utils.compat.selectors import SelectorKey, EVENT_READ +from ansible.module_utils.six.moves import shlex_quote +from ansible.module_utils._text import to_bytes +from ansible.playbook.play_context import PlayContext +from ansible.plugins.connection import ssh +from ansible.plugins.loader import connection_loader, become_loader + + +class TestConnectionBaseClass(unittest.TestCase): + + def test_plugins_connection_ssh_module(self): + play_context = PlayContext() + play_context.prompt = ( + '[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: ' + ) + in_stream = StringIO() + + self.assertIsInstance(ssh.Connection(play_context, in_stream), ssh.Connection) + + def test_plugins_connection_ssh_basic(self): + pc = PlayContext() + new_stdin = StringIO() + conn = ssh.Connection(pc, new_stdin) + + # connect just returns self, so assert that + res = conn._connect() + self.assertEqual(conn, res) + + ssh.SSHPASS_AVAILABLE = False + self.assertFalse(conn._sshpass_available()) + + ssh.SSHPASS_AVAILABLE = True + self.assertTrue(conn._sshpass_available()) + + with patch('subprocess.Popen') as p: + ssh.SSHPASS_AVAILABLE = None + p.return_value = MagicMock() + self.assertTrue(conn._sshpass_available()) + + ssh.SSHPASS_AVAILABLE = None + p.return_value = None + p.side_effect = OSError() + self.assertFalse(conn._sshpass_available()) + + conn.close() + self.assertFalse(conn._connected) + + def test_plugins_connection_ssh__build_command(self): + pc = PlayContext() + new_stdin = StringIO() + conn = connection_loader.get('ssh', pc, new_stdin) + conn._build_command('ssh') + + def test_plugins_connection_ssh_exec_command(self): + pc = PlayContext() + new_stdin = StringIO() + conn = connection_loader.get('ssh', pc, new_stdin) + + conn._build_command = MagicMock() + conn._build_command.return_value = 'ssh something something' + conn._run = MagicMock() + conn._run.return_value = (0, 'stdout', 'stderr') + conn.get_option = MagicMock() + conn.get_option.return_value = True + + res, stdout, stderr = conn.exec_command('ssh') + res, stdout, stderr = conn.exec_command('ssh', 'this is some data') + + def test_plugins_connection_ssh__examine_output(self): + pc = PlayContext() + new_stdin = StringIO() + + conn = connection_loader.get('ssh', pc, new_stdin) + conn.set_become_plugin(become_loader.get('sudo')) + + conn.check_password_prompt = MagicMock() + conn.check_become_success = MagicMock() + conn.check_incorrect_password = MagicMock() + conn.check_missing_password = MagicMock() + + def _check_password_prompt(line): + if b'foo' in line: + return True + return False + + def _check_become_success(line): + if b'BECOME-SUCCESS-abcdefghijklmnopqrstuvxyz' in line: + return True + return False + + def _check_incorrect_password(line): + if b'incorrect password' in line: + return True + return False + + def _check_missing_password(line): + if b'bad password' in line: + return True + return False + + conn.become.check_password_prompt = MagicMock(side_effect=_check_password_prompt) + conn.become.check_become_success = MagicMock(side_effect=_check_become_success) + conn.become.check_incorrect_password = MagicMock(side_effect=_check_incorrect_password) + conn.become.check_missing_password = MagicMock(side_effect=_check_missing_password) + + # test examining output for prompt + conn._flags = dict( + become_prompt=False, + become_success=False, + become_error=False, + become_nopasswd_error=False, + ) + + pc.prompt = True + conn.become.prompt = True + + def get_option(option): + if option == 'become_pass': + return 'password' + return None + + conn.become.get_option = get_option + output, unprocessed = conn._examine_output(u'source', u'state', b'line 1\nline 2\nfoo\nline 3\nthis should be the remainder', False) + self.assertEqual(output, b'line 1\nline 2\nline 3\n') + self.assertEqual(unprocessed, b'this should be the remainder') + self.assertTrue(conn._flags['become_prompt']) + self.assertFalse(conn._flags['become_success']) + self.assertFalse(conn._flags['become_error']) + self.assertFalse(conn._flags['become_nopasswd_error']) + + # test examining output for become prompt + conn._flags = dict( + become_prompt=False, + become_success=False, + become_error=False, + become_nopasswd_error=False, + ) + + pc.prompt = False + conn.become.prompt = False + pc.success_key = u'BECOME-SUCCESS-abcdefghijklmnopqrstuvxyz' + conn.become.success = u'BECOME-SUCCESS-abcdefghijklmnopqrstuvxyz' + output, unprocessed = conn._examine_output(u'source', u'state', b'line 1\nline 2\nBECOME-SUCCESS-abcdefghijklmnopqrstuvxyz\nline 3\n', False) + self.assertEqual(output, b'line 1\nline 2\nline 3\n') + self.assertEqual(unprocessed, b'') + self.assertFalse(conn._flags['become_prompt']) + self.assertTrue(conn._flags['become_success']) + self.assertFalse(conn._flags['become_error']) + self.assertFalse(conn._flags['become_nopasswd_error']) + + # test examining output for become failure + conn._flags = dict( + become_prompt=False, + become_success=False, + become_error=False, + become_nopasswd_error=False, + ) + + pc.prompt = False + conn.become.prompt = False + pc.success_key = None + output, unprocessed = conn._examine_output(u'source', u'state', b'line 1\nline 2\nincorrect password\n', True) + self.assertEqual(output, b'line 1\nline 2\nincorrect password\n') + self.assertEqual(unprocessed, b'') + self.assertFalse(conn._flags['become_prompt']) + self.assertFalse(conn._flags['become_success']) + self.assertTrue(conn._flags['become_error']) + self.assertFalse(conn._flags['become_nopasswd_error']) + + # test examining output for missing password + conn._flags = dict( + become_prompt=False, + become_success=False, + become_error=False, + become_nopasswd_error=False, + ) + + pc.prompt = False + conn.become.prompt = False + pc.success_key = None + output, unprocessed = conn._examine_output(u'source', u'state', b'line 1\nbad password\n', True) + self.assertEqual(output, b'line 1\nbad password\n') + self.assertEqual(unprocessed, b'') + self.assertFalse(conn._flags['become_prompt']) + self.assertFalse(conn._flags['become_success']) + self.assertFalse(conn._flags['become_error']) + self.assertTrue(conn._flags['become_nopasswd_error']) + + @patch('time.sleep') + @patch('os.path.exists') + def test_plugins_connection_ssh_put_file(self, mock_ospe, mock_sleep): + pc = PlayContext() + new_stdin = StringIO() + conn = connection_loader.get('ssh', pc, new_stdin) + conn._build_command = MagicMock() + conn._bare_run = MagicMock() + + mock_ospe.return_value = True + conn._build_command.return_value = 'some command to run' + conn._bare_run.return_value = (0, '', '') + conn.host = "some_host" + + C.ANSIBLE_SSH_RETRIES = 9 + + # Test with C.DEFAULT_SCP_IF_SSH set to smart + # Test when SFTP works + C.DEFAULT_SCP_IF_SSH = 'smart' + expected_in_data = b' '.join((b'put', to_bytes(shlex_quote('/path/to/in/file')), to_bytes(shlex_quote('/path/to/dest/file')))) + b'\n' + conn.put_file('/path/to/in/file', '/path/to/dest/file') + conn._bare_run.assert_called_with('some command to run', expected_in_data, checkrc=False) + + # Test when SFTP doesn't work but SCP does + conn._bare_run.side_effect = [(1, 'stdout', 'some errors'), (0, '', '')] + conn.put_file('/path/to/in/file', '/path/to/dest/file') + conn._bare_run.assert_called_with('some command to run', None, checkrc=False) + conn._bare_run.side_effect = None + + # test with C.DEFAULT_SCP_IF_SSH enabled + C.DEFAULT_SCP_IF_SSH = True + conn.put_file('/path/to/in/file', '/path/to/dest/file') + conn._bare_run.assert_called_with('some command to run', None, checkrc=False) + + conn.put_file(u'/path/to/in/file/with/unicode-fö〩', u'/path/to/dest/file/with/unicode-fö〩') + conn._bare_run.assert_called_with('some command to run', None, checkrc=False) + + # test with C.DEFAULT_SCP_IF_SSH disabled + C.DEFAULT_SCP_IF_SSH = False + expected_in_data = b' '.join((b'put', to_bytes(shlex_quote('/path/to/in/file')), to_bytes(shlex_quote('/path/to/dest/file')))) + b'\n' + conn.put_file('/path/to/in/file', '/path/to/dest/file') + conn._bare_run.assert_called_with('some command to run', expected_in_data, checkrc=False) + + expected_in_data = b' '.join((b'put', + to_bytes(shlex_quote('/path/to/in/file/with/unicode-fö〩')), + to_bytes(shlex_quote('/path/to/dest/file/with/unicode-fö〩')))) + b'\n' + conn.put_file(u'/path/to/in/file/with/unicode-fö〩', u'/path/to/dest/file/with/unicode-fö〩') + conn._bare_run.assert_called_with('some command to run', expected_in_data, checkrc=False) + + # test that a non-zero rc raises an error + conn._bare_run.return_value = (1, 'stdout', 'some errors') + self.assertRaises(AnsibleError, conn.put_file, '/path/to/bad/file', '/remote/path/to/file') + + # test that a not-found path raises an error + mock_ospe.return_value = False + conn._bare_run.return_value = (0, 'stdout', '') + self.assertRaises(AnsibleFileNotFound, conn.put_file, '/path/to/bad/file', '/remote/path/to/file') + + @patch('time.sleep') + def test_plugins_connection_ssh_fetch_file(self, mock_sleep): + pc = PlayContext() + new_stdin = StringIO() + conn = connection_loader.get('ssh', pc, new_stdin) + conn._build_command = MagicMock() + conn._bare_run = MagicMock() + conn._load_name = 'ssh' + + conn._build_command.return_value = 'some command to run' + conn._bare_run.return_value = (0, '', '') + conn.host = "some_host" + + C.ANSIBLE_SSH_RETRIES = 9 + + # Test with C.DEFAULT_SCP_IF_SSH set to smart + # Test when SFTP works + C.DEFAULT_SCP_IF_SSH = 'smart' + expected_in_data = b' '.join((b'get', to_bytes(shlex_quote('/path/to/in/file')), to_bytes(shlex_quote('/path/to/dest/file')))) + b'\n' + conn.set_options({}) + conn.fetch_file('/path/to/in/file', '/path/to/dest/file') + conn._bare_run.assert_called_with('some command to run', expected_in_data, checkrc=False) + + # Test when SFTP doesn't work but SCP does + conn._bare_run.side_effect = [(1, 'stdout', 'some errors'), (0, '', '')] + conn.fetch_file('/path/to/in/file', '/path/to/dest/file') + conn._bare_run.assert_called_with('some command to run', None, checkrc=False) + conn._bare_run.side_effect = None + + # test with C.DEFAULT_SCP_IF_SSH enabled + C.DEFAULT_SCP_IF_SSH = True + conn.fetch_file('/path/to/in/file', '/path/to/dest/file') + conn._bare_run.assert_called_with('some command to run', None, checkrc=False) + + conn.fetch_file(u'/path/to/in/file/with/unicode-fö〩', u'/path/to/dest/file/with/unicode-fö〩') + conn._bare_run.assert_called_with('some command to run', None, checkrc=False) + + # test with C.DEFAULT_SCP_IF_SSH disabled + C.DEFAULT_SCP_IF_SSH = False + expected_in_data = b' '.join((b'get', to_bytes(shlex_quote('/path/to/in/file')), to_bytes(shlex_quote('/path/to/dest/file')))) + b'\n' + conn.fetch_file('/path/to/in/file', '/path/to/dest/file') + conn._bare_run.assert_called_with('some command to run', expected_in_data, checkrc=False) + + expected_in_data = b' '.join((b'get', + to_bytes(shlex_quote('/path/to/in/file/with/unicode-fö〩')), + to_bytes(shlex_quote('/path/to/dest/file/with/unicode-fö〩')))) + b'\n' + conn.fetch_file(u'/path/to/in/file/with/unicode-fö〩', u'/path/to/dest/file/with/unicode-fö〩') + conn._bare_run.assert_called_with('some command to run', expected_in_data, checkrc=False) + + # test that a non-zero rc raises an error + conn._bare_run.return_value = (1, 'stdout', 'some errors') + self.assertRaises(AnsibleError, conn.fetch_file, '/path/to/bad/file', '/remote/path/to/file') + + +class MockSelector(object): + def __init__(self): + self.files_watched = 0 + self.register = MagicMock(side_effect=self._register) + self.unregister = MagicMock(side_effect=self._unregister) + self.close = MagicMock() + self.get_map = MagicMock(side_effect=self._get_map) + self.select = MagicMock() + + def _register(self, *args, **kwargs): + self.files_watched += 1 + + def _unregister(self, *args, **kwargs): + self.files_watched -= 1 + + def _get_map(self, *args, **kwargs): + return self.files_watched + + +@pytest.fixture +def mock_run_env(request, mocker): + pc = PlayContext() + new_stdin = StringIO() + + conn = connection_loader.get('ssh', pc, new_stdin) + conn.set_become_plugin(become_loader.get('sudo')) + conn._send_initial_data = MagicMock() + conn._examine_output = MagicMock() + conn._terminate_process = MagicMock() + conn._load_name = 'ssh' + conn.sshpass_pipe = [MagicMock(), MagicMock()] + + request.cls.pc = pc + request.cls.conn = conn + + mock_popen_res = MagicMock() + mock_popen_res.poll = MagicMock() + mock_popen_res.wait = MagicMock() + mock_popen_res.stdin = MagicMock() + mock_popen_res.stdin.fileno.return_value = 1000 + mock_popen_res.stdout = MagicMock() + mock_popen_res.stdout.fileno.return_value = 1001 + mock_popen_res.stderr = MagicMock() + mock_popen_res.stderr.fileno.return_value = 1002 + mock_popen_res.returncode = 0 + request.cls.mock_popen_res = mock_popen_res + + mock_popen = mocker.patch('subprocess.Popen', return_value=mock_popen_res) + request.cls.mock_popen = mock_popen + + request.cls.mock_selector = MockSelector() + mocker.patch('ansible.module_utils.compat.selectors.DefaultSelector', lambda: request.cls.mock_selector) + + request.cls.mock_openpty = mocker.patch('pty.openpty') + + mocker.patch('fcntl.fcntl') + mocker.patch('os.write') + mocker.patch('os.close') + + +@pytest.mark.usefixtures('mock_run_env') +class TestSSHConnectionRun(object): + # FIXME: + # These tests are little more than a smoketest. Need to enhance them + # a bit to check that they're calling the relevant functions and making + # complete coverage of the code paths + def test_no_escalation(self): + self.mock_popen_res.stdout.read.side_effect = [b"my_stdout\n", b"second_line"] + self.mock_popen_res.stderr.read.side_effect = [b"my_stderr"] + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + []] + self.mock_selector.get_map.side_effect = lambda: True + + return_code, b_stdout, b_stderr = self.conn._run("ssh", "this is input data") + assert return_code == 0 + assert b_stdout == b'my_stdout\nsecond_line' + assert b_stderr == b'my_stderr' + assert self.mock_selector.register.called is True + assert self.mock_selector.register.call_count == 2 + assert self.conn._send_initial_data.called is True + assert self.conn._send_initial_data.call_count == 1 + assert self.conn._send_initial_data.call_args[0][1] == 'this is input data' + + def test_with_password(self): + # test with a password set to trigger the sshpass write + self.pc.password = '12345' + self.mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""] + self.mock_popen_res.stderr.read.side_effect = [b""] + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + []] + self.mock_selector.get_map.side_effect = lambda: True + + return_code, b_stdout, b_stderr = self.conn._run(["ssh", "is", "a", "cmd"], "this is more data") + assert return_code == 0 + assert b_stdout == b'some data' + assert b_stderr == b'' + assert self.mock_selector.register.called is True + assert self.mock_selector.register.call_count == 2 + assert self.conn._send_initial_data.called is True + assert self.conn._send_initial_data.call_count == 1 + assert self.conn._send_initial_data.call_args[0][1] == 'this is more data' + + def _password_with_prompt_examine_output(self, sourice, state, b_chunk, sudoable): + if state == 'awaiting_prompt': + self.conn._flags['become_prompt'] = True + elif state == 'awaiting_escalation': + self.conn._flags['become_success'] = True + return (b'', b'') + + def test_password_with_prompt(self): + # test with password prompting enabled + self.pc.password = None + self.conn.become.prompt = b'Password:' + self.conn._examine_output.side_effect = self._password_with_prompt_examine_output + self.mock_popen_res.stdout.read.side_effect = [b"Password:", b"Success", b""] + self.mock_popen_res.stderr.read.side_effect = [b""] + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ), + (SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + []] + self.mock_selector.get_map.side_effect = lambda: True + + return_code, b_stdout, b_stderr = self.conn._run("ssh", "this is input data") + assert return_code == 0 + assert b_stdout == b'' + assert b_stderr == b'' + assert self.mock_selector.register.called is True + assert self.mock_selector.register.call_count == 2 + assert self.conn._send_initial_data.called is True + assert self.conn._send_initial_data.call_count == 1 + assert self.conn._send_initial_data.call_args[0][1] == 'this is input data' + + def test_password_with_become(self): + # test with some become settings + self.pc.prompt = b'Password:' + self.conn.become.prompt = b'Password:' + self.pc.become = True + self.pc.success_key = 'BECOME-SUCCESS-abcdefg' + self.conn.become._id = 'abcdefg' + self.conn._examine_output.side_effect = self._password_with_prompt_examine_output + self.mock_popen_res.stdout.read.side_effect = [b"Password:", b"BECOME-SUCCESS-abcdefg", b"abc"] + self.mock_popen_res.stderr.read.side_effect = [b"123"] + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + []] + self.mock_selector.get_map.side_effect = lambda: True + + return_code, b_stdout, b_stderr = self.conn._run("ssh", "this is input data") + self.mock_popen_res.stdin.flush.assert_called_once_with() + assert return_code == 0 + assert b_stdout == b'abc' + assert b_stderr == b'123' + assert self.mock_selector.register.called is True + assert self.mock_selector.register.call_count == 2 + assert self.conn._send_initial_data.called is True + assert self.conn._send_initial_data.call_count == 1 + assert self.conn._send_initial_data.call_args[0][1] == 'this is input data' + + def test_pasword_without_data(self): + # simulate no data input but Popen using new pty's fails + self.mock_popen.return_value = None + self.mock_popen.side_effect = [OSError(), self.mock_popen_res] + + # simulate no data input + self.mock_openpty.return_value = (98, 99) + self.mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""] + self.mock_popen_res.stderr.read.side_effect = [b""] + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + []] + self.mock_selector.get_map.side_effect = lambda: True + + return_code, b_stdout, b_stderr = self.conn._run("ssh", "") + assert return_code == 0 + assert b_stdout == b'some data' + assert b_stderr == b'' + assert self.mock_selector.register.called is True + assert self.mock_selector.register.call_count == 2 + assert self.conn._send_initial_data.called is False + + +@pytest.mark.usefixtures('mock_run_env') +class TestSSHConnectionRetries(object): + def test_incorrect_password(self, monkeypatch): + monkeypatch.setattr(C, 'HOST_KEY_CHECKING', False) + monkeypatch.setattr(C, 'ANSIBLE_SSH_RETRIES', 5) + monkeypatch.setattr('time.sleep', lambda x: None) + + self.mock_popen_res.stdout.read.side_effect = [b''] + self.mock_popen_res.stderr.read.side_effect = [b'Permission denied, please try again.\r\n'] + type(self.mock_popen_res).returncode = PropertyMock(side_effect=[5] * 4) + + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [], + ] + + self.mock_selector.get_map.side_effect = lambda: True + + self.conn._build_command = MagicMock() + self.conn._build_command.return_value = [b'sshpass', b'-d41', b'ssh', b'-C'] + self.conn.get_option = MagicMock() + self.conn.get_option.return_value = True + + exception_info = pytest.raises(AnsibleAuthenticationFailure, self.conn.exec_command, 'sshpass', 'some data') + assert exception_info.value.message == ('Invalid/incorrect username/password. Skipping remaining 5 retries to prevent account lockout: ' + 'Permission denied, please try again.') + assert self.mock_popen.call_count == 1 + + def test_retry_then_success(self, monkeypatch): + monkeypatch.setattr(C, 'HOST_KEY_CHECKING', False) + monkeypatch.setattr(C, 'ANSIBLE_SSH_RETRIES', 3) + + monkeypatch.setattr('time.sleep', lambda x: None) + + self.mock_popen_res.stdout.read.side_effect = [b"", b"my_stdout\n", b"second_line"] + self.mock_popen_res.stderr.read.side_effect = [b"", b"my_stderr"] + type(self.mock_popen_res).returncode = PropertyMock(side_effect=[255] * 3 + [0] * 4) + + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [] + ] + self.mock_selector.get_map.side_effect = lambda: True + + self.conn._build_command = MagicMock() + self.conn._build_command.return_value = 'ssh' + self.conn.get_option = MagicMock() + self.conn.get_option.return_value = True + + return_code, b_stdout, b_stderr = self.conn.exec_command('ssh', 'some data') + assert return_code == 0 + assert b_stdout == b'my_stdout\nsecond_line' + assert b_stderr == b'my_stderr' + + def test_multiple_failures(self, monkeypatch): + monkeypatch.setattr(C, 'HOST_KEY_CHECKING', False) + monkeypatch.setattr(C, 'ANSIBLE_SSH_RETRIES', 9) + + monkeypatch.setattr('time.sleep', lambda x: None) + + self.mock_popen_res.stdout.read.side_effect = [b""] * 10 + self.mock_popen_res.stderr.read.side_effect = [b""] * 10 + type(self.mock_popen_res).returncode = PropertyMock(side_effect=[255] * 30) + + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [], + ] * 10 + self.mock_selector.get_map.side_effect = lambda: True + + self.conn._build_command = MagicMock() + self.conn._build_command.return_value = 'ssh' + self.conn.get_option = MagicMock() + self.conn.get_option.return_value = True + + pytest.raises(AnsibleConnectionFailure, self.conn.exec_command, 'ssh', 'some data') + assert self.mock_popen.call_count == 10 + + def test_abitrary_exceptions(self, monkeypatch): + monkeypatch.setattr(C, 'HOST_KEY_CHECKING', False) + monkeypatch.setattr(C, 'ANSIBLE_SSH_RETRIES', 9) + + monkeypatch.setattr('time.sleep', lambda x: None) + + self.conn._build_command = MagicMock() + self.conn._build_command.return_value = 'ssh' + self.conn.get_option = MagicMock() + self.conn.get_option.return_value = True + + self.mock_popen.side_effect = [Exception('bad')] * 10 + pytest.raises(Exception, self.conn.exec_command, 'ssh', 'some data') + assert self.mock_popen.call_count == 10 + + def test_put_file_retries(self, monkeypatch): + monkeypatch.setattr(C, 'HOST_KEY_CHECKING', False) + monkeypatch.setattr(C, 'ANSIBLE_SSH_RETRIES', 3) + + monkeypatch.setattr('time.sleep', lambda x: None) + monkeypatch.setattr('ansible.plugins.connection.ssh.os.path.exists', lambda x: True) + + self.mock_popen_res.stdout.read.side_effect = [b"", b"my_stdout\n", b"second_line"] + self.mock_popen_res.stderr.read.side_effect = [b"", b"my_stderr"] + type(self.mock_popen_res).returncode = PropertyMock(side_effect=[255] * 4 + [0] * 4) + + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [] + ] + self.mock_selector.get_map.side_effect = lambda: True + + self.conn._build_command = MagicMock() + self.conn._build_command.return_value = 'sftp' + + return_code, b_stdout, b_stderr = self.conn.put_file('/path/to/in/file', '/path/to/dest/file') + assert return_code == 0 + assert b_stdout == b"my_stdout\nsecond_line" + assert b_stderr == b"my_stderr" + assert self.mock_popen.call_count == 2 + + def test_fetch_file_retries(self, monkeypatch): + monkeypatch.setattr(C, 'HOST_KEY_CHECKING', False) + monkeypatch.setattr(C, 'ANSIBLE_SSH_RETRIES', 3) + + monkeypatch.setattr('time.sleep', lambda x: None) + monkeypatch.setattr('ansible.plugins.connection.ssh.os.path.exists', lambda x: True) + + self.mock_popen_res.stdout.read.side_effect = [b"", b"my_stdout\n", b"second_line"] + self.mock_popen_res.stderr.read.side_effect = [b"", b"my_stderr"] + type(self.mock_popen_res).returncode = PropertyMock(side_effect=[255] * 4 + [0] * 4) + + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [] + ] + self.mock_selector.get_map.side_effect = lambda: True + + self.conn._build_command = MagicMock() + self.conn._build_command.return_value = 'sftp' + + return_code, b_stdout, b_stderr = self.conn.fetch_file('/path/to/in/file', '/path/to/dest/file') + assert return_code == 0 + assert b_stdout == b"my_stdout\nsecond_line" + assert b_stderr == b"my_stderr" + assert self.mock_popen.call_count == 2 diff --git a/test/units/plugins/connection/test_winrm.py b/test/units/plugins/connection/test_winrm.py new file mode 100644 index 00000000..67bfd9ae --- /dev/null +++ b/test/units/plugins/connection/test_winrm.py @@ -0,0 +1,431 @@ +# -*- coding: utf-8 -*- +# (c) 2018, Jordan Borean <jborean@redhat.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import pytest + +from io import StringIO + +from units.compat.mock import MagicMock +from ansible.errors import AnsibleConnectionFailure +from ansible.module_utils._text import to_bytes +from ansible.playbook.play_context import PlayContext +from ansible.plugins.loader import connection_loader +from ansible.plugins.connection import winrm + +pytest.importorskip("winrm") + + +class TestConnectionWinRM(object): + + OPTIONS_DATA = ( + # default options + ( + {'_extras': {}}, + {}, + { + '_kerb_managed': False, + '_kinit_cmd': 'kinit', + '_winrm_connection_timeout': None, + '_winrm_host': 'inventory_hostname', + '_winrm_kwargs': {'username': None, 'password': None}, + '_winrm_pass': None, + '_winrm_path': '/wsman', + '_winrm_port': 5986, + '_winrm_scheme': 'https', + '_winrm_transport': ['ssl'], + '_winrm_user': None + }, + False + ), + # http through port + ( + {'_extras': {}, 'ansible_port': 5985}, + {}, + { + '_winrm_kwargs': {'username': None, 'password': None}, + '_winrm_port': 5985, + '_winrm_scheme': 'http', + '_winrm_transport': ['plaintext'], + }, + False + ), + # kerberos user with kerb present + ( + {'_extras': {}, 'ansible_user': 'user@domain.com'}, + {}, + { + '_kerb_managed': False, + '_kinit_cmd': 'kinit', + '_winrm_kwargs': {'username': 'user@domain.com', + 'password': None}, + '_winrm_pass': None, + '_winrm_transport': ['kerberos', 'ssl'], + '_winrm_user': 'user@domain.com' + }, + True + ), + # kerberos user without kerb present + ( + {'_extras': {}, 'ansible_user': 'user@domain.com'}, + {}, + { + '_kerb_managed': False, + '_kinit_cmd': 'kinit', + '_winrm_kwargs': {'username': 'user@domain.com', + 'password': None}, + '_winrm_pass': None, + '_winrm_transport': ['ssl'], + '_winrm_user': 'user@domain.com' + }, + False + ), + # kerberos user with managed ticket (implicit) + ( + {'_extras': {}, 'ansible_user': 'user@domain.com'}, + {'remote_password': 'pass'}, + { + '_kerb_managed': True, + '_kinit_cmd': 'kinit', + '_winrm_kwargs': {'username': 'user@domain.com', + 'password': 'pass'}, + '_winrm_pass': 'pass', + '_winrm_transport': ['kerberos', 'ssl'], + '_winrm_user': 'user@domain.com' + }, + True + ), + # kerb with managed ticket (explicit) + ( + {'_extras': {}, 'ansible_user': 'user@domain.com', + 'ansible_winrm_kinit_mode': 'managed'}, + {'password': 'pass'}, + { + '_kerb_managed': True, + }, + True + ), + # kerb with unmanaged ticket (explicit)) + ( + {'_extras': {}, 'ansible_user': 'user@domain.com', + 'ansible_winrm_kinit_mode': 'manual'}, + {'password': 'pass'}, + { + '_kerb_managed': False, + }, + True + ), + # transport override (single) + ( + {'_extras': {}, 'ansible_user': 'user@domain.com', + 'ansible_winrm_transport': 'ntlm'}, + {}, + { + '_winrm_kwargs': {'username': 'user@domain.com', + 'password': None}, + '_winrm_pass': None, + '_winrm_transport': ['ntlm'], + }, + False + ), + # transport override (list) + ( + {'_extras': {}, 'ansible_user': 'user@domain.com', + 'ansible_winrm_transport': ['ntlm', 'certificate']}, + {}, + { + '_winrm_kwargs': {'username': 'user@domain.com', + 'password': None}, + '_winrm_pass': None, + '_winrm_transport': ['ntlm', 'certificate'], + }, + False + ), + # winrm extras + ( + {'_extras': {'ansible_winrm_server_cert_validation': 'ignore', + 'ansible_winrm_service': 'WSMAN'}}, + {}, + { + '_winrm_kwargs': {'username': None, 'password': None, + 'server_cert_validation': 'ignore', + 'service': 'WSMAN'}, + }, + False + ), + # direct override + ( + {'_extras': {}, 'ansible_winrm_connection_timeout': 5}, + {'connection_timeout': 10}, + { + '_winrm_connection_timeout': 10, + }, + False + ), + # password as ansible_password + ( + {'_extras': {}, 'ansible_password': 'pass'}, + {}, + { + '_winrm_pass': 'pass', + '_winrm_kwargs': {'username': None, 'password': 'pass'} + }, + False + ), + # password as ansible_winrm_pass + ( + {'_extras': {}, 'ansible_winrm_pass': 'pass'}, + {}, + { + '_winrm_pass': 'pass', + '_winrm_kwargs': {'username': None, 'password': 'pass'} + }, + False + ), + + # password as ansible_winrm_password + ( + {'_extras': {}, 'ansible_winrm_password': 'pass'}, + {}, + { + '_winrm_pass': 'pass', + '_winrm_kwargs': {'username': None, 'password': 'pass'} + }, + False + ), + ) + + # pylint bug: https://github.com/PyCQA/pylint/issues/511 + # pylint: disable=undefined-variable + @pytest.mark.parametrize('options, direct, expected, kerb', + ((o, d, e, k) for o, d, e, k in OPTIONS_DATA)) + def test_set_options(self, options, direct, expected, kerb): + winrm.HAVE_KERBEROS = kerb + + pc = PlayContext() + new_stdin = StringIO() + + conn = connection_loader.get('winrm', pc, new_stdin) + conn.set_options(var_options=options, direct=direct) + conn._build_winrm_kwargs() + + for attr, expected in expected.items(): + actual = getattr(conn, attr) + assert actual == expected, \ + "winrm attr '%s', actual '%s' != expected '%s'"\ + % (attr, actual, expected) + + +class TestWinRMKerbAuth(object): + + @pytest.mark.parametrize('options, expected', [ + [{"_extras": {}}, + (["kinit", "user@domain"],)], + [{"_extras": {}, 'ansible_winrm_kinit_cmd': 'kinit2'}, + (["kinit2", "user@domain"],)], + [{"_extras": {'ansible_winrm_kerberos_delegation': True}}, + (["kinit", "-f", "user@domain"],)], + ]) + def test_kinit_success_subprocess(self, monkeypatch, options, expected): + def mock_communicate(input=None, timeout=None): + return b"", b"" + + mock_popen = MagicMock() + mock_popen.return_value.communicate = mock_communicate + mock_popen.return_value.returncode = 0 + monkeypatch.setattr("subprocess.Popen", mock_popen) + + winrm.HAS_PEXPECT = False + pc = PlayContext() + new_stdin = StringIO() + conn = connection_loader.get('winrm', pc, new_stdin) + conn.set_options(var_options=options) + conn._build_winrm_kwargs() + + conn._kerb_auth("user@domain", "pass") + mock_calls = mock_popen.mock_calls + assert len(mock_calls) == 1 + assert mock_calls[0][1] == expected + actual_env = mock_calls[0][2]['env'] + assert list(actual_env.keys()) == ['KRB5CCNAME'] + assert actual_env['KRB5CCNAME'].startswith("FILE:/") + + @pytest.mark.parametrize('options, expected', [ + [{"_extras": {}}, + ("kinit", ["user@domain"],)], + [{"_extras": {}, 'ansible_winrm_kinit_cmd': 'kinit2'}, + ("kinit2", ["user@domain"],)], + [{"_extras": {'ansible_winrm_kerberos_delegation': True}}, + ("kinit", ["-f", "user@domain"],)], + ]) + def test_kinit_success_pexpect(self, monkeypatch, options, expected): + pytest.importorskip("pexpect") + mock_pexpect = MagicMock() + mock_pexpect.return_value.exitstatus = 0 + monkeypatch.setattr("pexpect.spawn", mock_pexpect) + + winrm.HAS_PEXPECT = True + pc = PlayContext() + new_stdin = StringIO() + conn = connection_loader.get('winrm', pc, new_stdin) + conn.set_options(var_options=options) + conn._build_winrm_kwargs() + + conn._kerb_auth("user@domain", "pass") + mock_calls = mock_pexpect.mock_calls + assert mock_calls[0][1] == expected + actual_env = mock_calls[0][2]['env'] + assert list(actual_env.keys()) == ['KRB5CCNAME'] + assert actual_env['KRB5CCNAME'].startswith("FILE:/") + assert mock_calls[0][2]['echo'] is False + assert mock_calls[1][0] == "().expect" + assert mock_calls[1][1] == (".*:",) + assert mock_calls[2][0] == "().sendline" + assert mock_calls[2][1] == ("pass",) + assert mock_calls[3][0] == "().read" + assert mock_calls[4][0] == "().wait" + + def test_kinit_with_missing_executable_subprocess(self, monkeypatch): + expected_err = "[Errno 2] No such file or directory: " \ + "'/fake/kinit': '/fake/kinit'" + mock_popen = MagicMock(side_effect=OSError(expected_err)) + + monkeypatch.setattr("subprocess.Popen", mock_popen) + + winrm.HAS_PEXPECT = False + pc = PlayContext() + new_stdin = StringIO() + conn = connection_loader.get('winrm', pc, new_stdin) + options = {"_extras": {}, "ansible_winrm_kinit_cmd": "/fake/kinit"} + conn.set_options(var_options=options) + conn._build_winrm_kwargs() + + with pytest.raises(AnsibleConnectionFailure) as err: + conn._kerb_auth("user@domain", "pass") + assert str(err.value) == "Kerberos auth failure when calling " \ + "kinit cmd '/fake/kinit': %s" % expected_err + + def test_kinit_with_missing_executable_pexpect(self, monkeypatch): + pexpect = pytest.importorskip("pexpect") + + expected_err = "The command was not found or was not " \ + "executable: /fake/kinit" + mock_pexpect = \ + MagicMock(side_effect=pexpect.ExceptionPexpect(expected_err)) + + monkeypatch.setattr("pexpect.spawn", mock_pexpect) + + winrm.HAS_PEXPECT = True + pc = PlayContext() + new_stdin = StringIO() + conn = connection_loader.get('winrm', pc, new_stdin) + options = {"_extras": {}, "ansible_winrm_kinit_cmd": "/fake/kinit"} + conn.set_options(var_options=options) + conn._build_winrm_kwargs() + + with pytest.raises(AnsibleConnectionFailure) as err: + conn._kerb_auth("user@domain", "pass") + assert str(err.value) == "Kerberos auth failure when calling " \ + "kinit cmd '/fake/kinit': %s" % expected_err + + def test_kinit_error_subprocess(self, monkeypatch): + expected_err = "kinit: krb5_parse_name: " \ + "Configuration file does not specify default realm" + + def mock_communicate(input=None, timeout=None): + return b"", to_bytes(expected_err) + + mock_popen = MagicMock() + mock_popen.return_value.communicate = mock_communicate + mock_popen.return_value.returncode = 1 + monkeypatch.setattr("subprocess.Popen", mock_popen) + + winrm.HAS_PEXPECT = False + pc = PlayContext() + new_stdin = StringIO() + conn = connection_loader.get('winrm', pc, new_stdin) + conn.set_options(var_options={"_extras": {}}) + conn._build_winrm_kwargs() + + with pytest.raises(AnsibleConnectionFailure) as err: + conn._kerb_auth("invaliduser", "pass") + + assert str(err.value) == \ + "Kerberos auth failure for principal invaliduser with " \ + "subprocess: %s" % (expected_err) + + def test_kinit_error_pexpect(self, monkeypatch): + pytest.importorskip("pexpect") + + expected_err = "Configuration file does not specify default realm" + mock_pexpect = MagicMock() + mock_pexpect.return_value.expect = MagicMock(side_effect=OSError) + mock_pexpect.return_value.read.return_value = to_bytes(expected_err) + mock_pexpect.return_value.exitstatus = 1 + + monkeypatch.setattr("pexpect.spawn", mock_pexpect) + + winrm.HAS_PEXPECT = True + pc = PlayContext() + new_stdin = StringIO() + conn = connection_loader.get('winrm', pc, new_stdin) + conn.set_options(var_options={"_extras": {}}) + conn._build_winrm_kwargs() + + with pytest.raises(AnsibleConnectionFailure) as err: + conn._kerb_auth("invaliduser", "pass") + + assert str(err.value) == \ + "Kerberos auth failure for principal invaliduser with " \ + "pexpect: %s" % (expected_err) + + def test_kinit_error_pass_in_output_subprocess(self, monkeypatch): + def mock_communicate(input=None, timeout=None): + return b"", b"Error with kinit\n" + input + + mock_popen = MagicMock() + mock_popen.return_value.communicate = mock_communicate + mock_popen.return_value.returncode = 1 + monkeypatch.setattr("subprocess.Popen", mock_popen) + + winrm.HAS_PEXPECT = False + pc = PlayContext() + new_stdin = StringIO() + conn = connection_loader.get('winrm', pc, new_stdin) + conn.set_options(var_options={"_extras": {}}) + conn._build_winrm_kwargs() + + with pytest.raises(AnsibleConnectionFailure) as err: + conn._kerb_auth("username", "password") + assert str(err.value) == \ + "Kerberos auth failure for principal username with subprocess: " \ + "Error with kinit\n<redacted>" + + def test_kinit_error_pass_in_output_pexpect(self, monkeypatch): + pytest.importorskip("pexpect") + + mock_pexpect = MagicMock() + mock_pexpect.return_value.expect = MagicMock() + mock_pexpect.return_value.read.return_value = \ + b"Error with kinit\npassword\n" + mock_pexpect.return_value.exitstatus = 1 + + monkeypatch.setattr("pexpect.spawn", mock_pexpect) + + winrm.HAS_PEXPECT = True + pc = PlayContext() + pc = PlayContext() + new_stdin = StringIO() + conn = connection_loader.get('winrm', pc, new_stdin) + conn.set_options(var_options={"_extras": {}}) + conn._build_winrm_kwargs() + + with pytest.raises(AnsibleConnectionFailure) as err: + conn._kerb_auth("username", "password") + assert str(err.value) == \ + "Kerberos auth failure for principal username with pexpect: " \ + "Error with kinit\n<redacted>" diff --git a/test/units/plugins/filter/__init__.py b/test/units/plugins/filter/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/units/plugins/filter/__init__.py diff --git a/test/units/plugins/filter/test_core.py b/test/units/plugins/filter/test_core.py new file mode 100644 index 00000000..8a626d9a --- /dev/null +++ b/test/units/plugins/filter/test_core.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import pytest + +from ansible.module_utils._text import to_native +from ansible.plugins.filter.core import to_uuid +from ansible.errors import AnsibleFilterError + + +UUID_DEFAULT_NAMESPACE_TEST_CASES = ( + ('example.com', 'ae780c3a-a3ab-53c2-bfb4-098da300b3fe'), + ('test.example', '8e437a35-c7c5-50ea-867c-5c254848dbc2'), + ('café.example', '8a99d6b1-fb8f-5f78-af86-879768589f56'), +) + +UUID_TEST_CASES = ( + ('361E6D51-FAEC-444A-9079-341386DA8E2E', 'example.com', 'ae780c3a-a3ab-53c2-bfb4-098da300b3fe'), + ('361E6D51-FAEC-444A-9079-341386DA8E2E', 'test.example', '8e437a35-c7c5-50ea-867c-5c254848dbc2'), + ('11111111-2222-3333-4444-555555555555', 'example.com', 'e776faa5-5299-55dc-9057-7a00e6be2364'), +) + + +@pytest.mark.parametrize('value, expected', UUID_DEFAULT_NAMESPACE_TEST_CASES) +def test_to_uuid_default_namespace(value, expected): + assert expected == to_uuid(value) + + +@pytest.mark.parametrize('namespace, value, expected', UUID_TEST_CASES) +def test_to_uuid(namespace, value, expected): + assert expected == to_uuid(value, namespace=namespace) + + +def test_to_uuid_invalid_namespace(): + with pytest.raises(AnsibleFilterError) as e: + to_uuid('example.com', namespace='11111111-2222-3333-4444-555555555') + assert 'Invalid value' in to_native(e.value) diff --git a/test/units/plugins/filter/test_mathstuff.py b/test/units/plugins/filter/test_mathstuff.py new file mode 100644 index 00000000..a0e78d33 --- /dev/null +++ b/test/units/plugins/filter/test_mathstuff.py @@ -0,0 +1,176 @@ +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type +import pytest + +from jinja2 import Environment + +import ansible.plugins.filter.mathstuff as ms +from ansible.errors import AnsibleFilterError, AnsibleFilterTypeError + + +UNIQUE_DATA = (([1, 3, 4, 2], sorted([1, 2, 3, 4])), + ([1, 3, 2, 4, 2, 3], sorted([1, 2, 3, 4])), + (['a', 'b', 'c', 'd'], sorted(['a', 'b', 'c', 'd'])), + (['a', 'a', 'd', 'b', 'a', 'd', 'c', 'b'], sorted(['a', 'b', 'c', 'd'])), + ) + +TWO_SETS_DATA = (([1, 2], [3, 4], ([], sorted([1, 2]), sorted([1, 2, 3, 4]), sorted([1, 2, 3, 4]))), + ([1, 2, 3], [5, 3, 4], ([3], sorted([1, 2]), sorted([1, 2, 5, 4]), sorted([1, 2, 3, 4, 5]))), + (['a', 'b', 'c'], ['d', 'c', 'e'], (['c'], sorted(['a', 'b']), sorted(['a', 'b', 'd', 'e']), sorted(['a', 'b', 'c', 'e', 'd']))), + ) + +env = Environment() + + +@pytest.mark.parametrize('data, expected', UNIQUE_DATA) +class TestUnique: + def test_unhashable(self, data, expected): + assert sorted(ms.unique(env, list(data))) == expected + + def test_hashable(self, data, expected): + assert sorted(ms.unique(env, tuple(data))) == expected + + +@pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA) +class TestIntersect: + def test_unhashable(self, dataset1, dataset2, expected): + assert sorted(ms.intersect(env, list(dataset1), list(dataset2))) == expected[0] + + def test_hashable(self, dataset1, dataset2, expected): + assert sorted(ms.intersect(env, tuple(dataset1), tuple(dataset2))) == expected[0] + + +@pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA) +class TestDifference: + def test_unhashable(self, dataset1, dataset2, expected): + assert sorted(ms.difference(env, list(dataset1), list(dataset2))) == expected[1] + + def test_hashable(self, dataset1, dataset2, expected): + assert sorted(ms.difference(env, tuple(dataset1), tuple(dataset2))) == expected[1] + + +@pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA) +class TestSymmetricDifference: + def test_unhashable(self, dataset1, dataset2, expected): + assert sorted(ms.symmetric_difference(env, list(dataset1), list(dataset2))) == expected[2] + + def test_hashable(self, dataset1, dataset2, expected): + assert sorted(ms.symmetric_difference(env, tuple(dataset1), tuple(dataset2))) == expected[2] + + +class TestMin: + def test_min(self): + assert ms.min((1, 2)) == 1 + assert ms.min((2, 1)) == 1 + assert ms.min(('p', 'a', 'w', 'b', 'p')) == 'a' + + +class TestMax: + def test_max(self): + assert ms.max((1, 2)) == 2 + assert ms.max((2, 1)) == 2 + assert ms.max(('p', 'a', 'w', 'b', 'p')) == 'w' + + +class TestLogarithm: + def test_log_non_number(self): + # Message changed in python3.6 + with pytest.raises(AnsibleFilterTypeError, match='log\\(\\) can only be used on numbers: (a float is required|must be real number, not str)'): + ms.logarithm('a') + with pytest.raises(AnsibleFilterTypeError, match='log\\(\\) can only be used on numbers: (a float is required|must be real number, not str)'): + ms.logarithm(10, base='a') + + def test_log_ten(self): + assert ms.logarithm(10, 10) == 1.0 + assert ms.logarithm(69, 10) * 1000 // 1 == 1838 + + def test_log_natural(self): + assert ms.logarithm(69) * 1000 // 1 == 4234 + + def test_log_two(self): + assert ms.logarithm(69, 2) * 1000 // 1 == 6108 + + +class TestPower: + def test_power_non_number(self): + # Message changed in python3.6 + with pytest.raises(AnsibleFilterTypeError, match='pow\\(\\) can only be used on numbers: (a float is required|must be real number, not str)'): + ms.power('a', 10) + + with pytest.raises(AnsibleFilterTypeError, match='pow\\(\\) can only be used on numbers: (a float is required|must be real number, not str)'): + ms.power(10, 'a') + + def test_power_squared(self): + assert ms.power(10, 2) == 100 + + def test_power_cubed(self): + assert ms.power(10, 3) == 1000 + + +class TestInversePower: + def test_root_non_number(self): + # Messages differed in python-2.6, python-2.7-3.5, and python-3.6+ + with pytest.raises(AnsibleFilterTypeError, match="root\\(\\) can only be used on numbers:" + " (invalid literal for float\\(\\): a" + "|could not convert string to float: a" + "|could not convert string to float: 'a')"): + ms.inversepower(10, 'a') + + with pytest.raises(AnsibleFilterTypeError, match="root\\(\\) can only be used on numbers: (a float is required|must be real number, not str)"): + ms.inversepower('a', 10) + + def test_square_root(self): + assert ms.inversepower(100) == 10 + assert ms.inversepower(100, 2) == 10 + + def test_cube_root(self): + assert ms.inversepower(27, 3) == 3 + + +class TestRekeyOnMember(): + # (Input data structure, member to rekey on, expected return) + VALID_ENTRIES = ( + ([{"proto": "eigrp", "state": "enabled"}, {"proto": "ospf", "state": "enabled"}], + 'proto', + {'eigrp': {'state': 'enabled', 'proto': 'eigrp'}, 'ospf': {'state': 'enabled', 'proto': 'ospf'}}), + ({'eigrp': {"proto": "eigrp", "state": "enabled"}, 'ospf': {"proto": "ospf", "state": "enabled"}}, + 'proto', + {'eigrp': {'state': 'enabled', 'proto': 'eigrp'}, 'ospf': {'state': 'enabled', 'proto': 'ospf'}}), + ) + + # (Input data structure, member to rekey on, expected error message) + INVALID_ENTRIES = ( + # Fail when key is not found + (AnsibleFilterError, [{"proto": "eigrp", "state": "enabled"}], 'invalid_key', "Key invalid_key was not found"), + (AnsibleFilterError, {"eigrp": {"proto": "eigrp", "state": "enabled"}}, 'invalid_key', "Key invalid_key was not found"), + # Fail when key is duplicated + (AnsibleFilterError, [{"proto": "eigrp"}, {"proto": "ospf"}, {"proto": "ospf"}], + 'proto', 'Key ospf is not unique, cannot correctly turn into dict'), + # Fail when value is not a dict + (AnsibleFilterTypeError, ["string"], 'proto', "List item is not a valid dict"), + (AnsibleFilterTypeError, [123], 'proto', "List item is not a valid dict"), + (AnsibleFilterTypeError, [[{'proto': 1}]], 'proto', "List item is not a valid dict"), + # Fail when we do not send a dict or list + (AnsibleFilterTypeError, "string", 'proto', "Type is not a valid list, set, or dict"), + (AnsibleFilterTypeError, 123, 'proto', "Type is not a valid list, set, or dict"), + ) + + @pytest.mark.parametrize("list_original, key, expected", VALID_ENTRIES) + def test_rekey_on_member_success(self, list_original, key, expected): + assert ms.rekey_on_member(list_original, key) == expected + + @pytest.mark.parametrize("expected_exception_type, list_original, key, expected", INVALID_ENTRIES) + def test_fail_rekey_on_member(self, expected_exception_type, list_original, key, expected): + with pytest.raises(expected_exception_type) as err: + ms.rekey_on_member(list_original, key) + + assert err.value.message == expected + + def test_duplicate_strategy_overwrite(self): + list_original = ({'proto': 'eigrp', 'id': 1}, {'proto': 'ospf', 'id': 2}, {'proto': 'eigrp', 'id': 3}) + expected = {'eigrp': {'proto': 'eigrp', 'id': 3}, 'ospf': {'proto': 'ospf', 'id': 2}} + assert ms.rekey_on_member(list_original, 'proto', duplicates='overwrite') == expected diff --git a/test/units/plugins/inventory/__init__.py b/test/units/plugins/inventory/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/units/plugins/inventory/__init__.py diff --git a/test/units/plugins/inventory/test_constructed.py b/test/units/plugins/inventory/test_constructed.py new file mode 100644 index 00000000..6d521982 --- /dev/null +++ b/test/units/plugins/inventory/test_constructed.py @@ -0,0 +1,206 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 Alan Rominger <arominge@redhat.net> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import pytest + +from ansible.errors import AnsibleParserError +from ansible.plugins.inventory.constructed import InventoryModule +from ansible.inventory.data import InventoryData +from ansible.template import Templar + + +@pytest.fixture() +def inventory_module(): + r = InventoryModule() + r.inventory = InventoryData() + r.templar = Templar(None) + return r + + +def test_group_by_value_only(inventory_module): + inventory_module.inventory.add_host('foohost') + inventory_module.inventory.set_variable('foohost', 'bar', 'my_group_name') + host = inventory_module.inventory.get_host('foohost') + keyed_groups = [ + { + 'prefix': '', + 'separator': '', + 'key': 'bar' + } + ] + inventory_module._add_host_to_keyed_groups( + keyed_groups, host.vars, host.name, strict=False + ) + assert 'my_group_name' in inventory_module.inventory.groups + group = inventory_module.inventory.groups['my_group_name'] + assert group.hosts == [host] + + +def test_keyed_group_separator(inventory_module): + inventory_module.inventory.add_host('farm') + inventory_module.inventory.set_variable('farm', 'farmer', 'mcdonald') + inventory_module.inventory.set_variable('farm', 'barn', {'cow': 'betsy'}) + host = inventory_module.inventory.get_host('farm') + keyed_groups = [ + { + 'prefix': 'farmer', + 'separator': '_old_', + 'key': 'farmer' + }, + { + 'separator': 'mmmmmmmmmm', + 'key': 'barn' + } + ] + inventory_module._add_host_to_keyed_groups( + keyed_groups, host.vars, host.name, strict=False + ) + for group_name in ('farmer_old_mcdonald', 'mmmmmmmmmmcowmmmmmmmmmmbetsy'): + assert group_name in inventory_module.inventory.groups + group = inventory_module.inventory.groups[group_name] + assert group.hosts == [host] + + +def test_keyed_group_empty_construction(inventory_module): + inventory_module.inventory.add_host('farm') + inventory_module.inventory.set_variable('farm', 'barn', {}) + host = inventory_module.inventory.get_host('farm') + keyed_groups = [ + { + 'separator': 'mmmmmmmmmm', + 'key': 'barn' + } + ] + inventory_module._add_host_to_keyed_groups( + keyed_groups, host.vars, host.name, strict=True + ) + assert host.groups == [] + + +def test_keyed_group_host_confusion(inventory_module): + inventory_module.inventory.add_host('cow') + inventory_module.inventory.add_group('cow') + host = inventory_module.inventory.get_host('cow') + host.vars['species'] = 'cow' + keyed_groups = [ + { + 'separator': '', + 'prefix': '', + 'key': 'species' + } + ] + inventory_module._add_host_to_keyed_groups( + keyed_groups, host.vars, host.name, strict=True + ) + group = inventory_module.inventory.groups['cow'] + # group cow has host of cow + assert group.hosts == [host] + + +def test_keyed_parent_groups(inventory_module): + inventory_module.inventory.add_host('web1') + inventory_module.inventory.add_host('web2') + inventory_module.inventory.set_variable('web1', 'region', 'japan') + inventory_module.inventory.set_variable('web2', 'region', 'japan') + host1 = inventory_module.inventory.get_host('web1') + host2 = inventory_module.inventory.get_host('web2') + keyed_groups = [ + { + 'prefix': 'region', + 'key': 'region', + 'parent_group': 'region_list' + } + ] + for host in [host1, host2]: + inventory_module._add_host_to_keyed_groups( + keyed_groups, host.vars, host.name, strict=False + ) + assert 'region_japan' in inventory_module.inventory.groups + assert 'region_list' in inventory_module.inventory.groups + region_group = inventory_module.inventory.groups['region_japan'] + all_regions = inventory_module.inventory.groups['region_list'] + assert all_regions.child_groups == [region_group] + assert region_group.hosts == [host1, host2] + + +def test_parent_group_templating(inventory_module): + inventory_module.inventory.add_host('cow') + inventory_module.inventory.set_variable('cow', 'sound', 'mmmmmmmmmm') + inventory_module.inventory.set_variable('cow', 'nickname', 'betsy') + host = inventory_module.inventory.get_host('cow') + keyed_groups = [ + { + 'key': 'sound', + 'prefix': 'sound', + 'parent_group': '{{ nickname }}' + }, + { + 'key': 'nickname', + 'prefix': '', + 'separator': '', + 'parent_group': 'nickname' # statically-named parent group, conflicting with hostvar + }, + { + 'key': 'nickname', + 'separator': '', + 'parent_group': '{{ location | default("field") }}' + } + ] + inventory_module._add_host_to_keyed_groups( + keyed_groups, host.vars, host.name, strict=True + ) + # first keyed group, "betsy" is a parent group name dynamically generated + betsys_group = inventory_module.inventory.groups['betsy'] + assert [child.name for child in betsys_group.child_groups] == ['sound_mmmmmmmmmm'] + # second keyed group, "nickname" is a statically-named root group + nicknames_group = inventory_module.inventory.groups['nickname'] + assert [child.name for child in nicknames_group.child_groups] == ['betsy'] + # second keyed group actually generated the parent group of the first keyed group + # assert that these are, in fact, the same object + assert nicknames_group.child_groups[0] == betsys_group + # second keyed group has two parents + locations_group = inventory_module.inventory.groups['field'] + assert [child.name for child in locations_group.child_groups] == ['betsy'] + + +def test_parent_group_templating_error(inventory_module): + inventory_module.inventory.add_host('cow') + inventory_module.inventory.set_variable('cow', 'nickname', 'betsy') + host = inventory_module.inventory.get_host('cow') + keyed_groups = [ + { + 'key': 'nickname', + 'separator': '', + 'parent_group': '{{ location.barn-yard }}' + } + ] + with pytest.raises(AnsibleParserError) as err_message: + inventory_module._add_host_to_keyed_groups( + keyed_groups, host.vars, host.name, strict=True + ) + assert 'Could not generate parent group' in err_message + # invalid parent group did not raise an exception with strict=False + inventory_module._add_host_to_keyed_groups( + keyed_groups, host.vars, host.name, strict=False + ) + # assert group was never added with invalid parent + assert 'betsy' not in inventory_module.inventory.groups diff --git a/test/units/plugins/inventory/test_inventory.py b/test/units/plugins/inventory/test_inventory.py new file mode 100644 index 00000000..66b5ec37 --- /dev/null +++ b/test/units/plugins/inventory/test_inventory.py @@ -0,0 +1,207 @@ +# Copyright 2015 Abhijit Menon-Sen <ams@2ndQuadrant.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import string +import textwrap + +from ansible import constants as C +from units.compat import mock +from units.compat import unittest +from ansible.module_utils.six import string_types +from ansible.module_utils._text import to_text +from units.mock.path import mock_unfrackpath_noop + +from ansible.inventory.manager import InventoryManager, split_host_pattern + +from units.mock.loader import DictDataLoader + + +class TestInventory(unittest.TestCase): + + patterns = { + 'a': ['a'], + 'a, b': ['a', 'b'], + 'a , b': ['a', 'b'], + ' a,b ,c[1:2] ': ['a', 'b', 'c[1:2]'], + '9a01:7f8:191:7701::9': ['9a01:7f8:191:7701::9'], + '9a01:7f8:191:7701::9,9a01:7f8:191:7701::9': ['9a01:7f8:191:7701::9', '9a01:7f8:191:7701::9'], + '9a01:7f8:191:7701::9,9a01:7f8:191:7701::9,foo': ['9a01:7f8:191:7701::9', '9a01:7f8:191:7701::9', 'foo'], + 'foo[1:2]': ['foo[1:2]'], + 'a::b': ['a::b'], + 'a:b': ['a', 'b'], + ' a : b ': ['a', 'b'], + 'foo:bar:baz[1:2]': ['foo', 'bar', 'baz[1:2]'], + 'a,,b': ['a', 'b'], + 'a, ,b,,c, ,': ['a', 'b', 'c'], + ',': [], + '': [], + } + + pattern_lists = [ + [['a'], ['a']], + [['a', 'b'], ['a', 'b']], + [['a, b'], ['a', 'b']], + [['9a01:7f8:191:7701::9', '9a01:7f8:191:7701::9,foo'], + ['9a01:7f8:191:7701::9', '9a01:7f8:191:7701::9', 'foo']] + ] + + # pattern_string: [ ('base_pattern', (a,b)), ['x','y','z'] ] + # a,b are the bounds of the subscript; x..z are the results of the subscript + # when applied to string.ascii_letters. + + subscripts = { + 'a': [('a', None), list(string.ascii_letters)], + 'a[0]': [('a', (0, None)), ['a']], + 'a[1]': [('a', (1, None)), ['b']], + 'a[2:3]': [('a', (2, 3)), ['c', 'd']], + 'a[-1]': [('a', (-1, None)), ['Z']], + 'a[-2]': [('a', (-2, None)), ['Y']], + 'a[48:]': [('a', (48, -1)), ['W', 'X', 'Y', 'Z']], + 'a[49:]': [('a', (49, -1)), ['X', 'Y', 'Z']], + 'a[1:]': [('a', (1, -1)), list(string.ascii_letters[1:])], + } + + ranges_to_expand = { + 'a[1:2]': ['a1', 'a2'], + 'a[1:10:2]': ['a1', 'a3', 'a5', 'a7', 'a9'], + 'a[a:b]': ['aa', 'ab'], + 'a[a:i:3]': ['aa', 'ad', 'ag'], + 'a[a:b][c:d]': ['aac', 'aad', 'abc', 'abd'], + 'a[0:1][2:3]': ['a02', 'a03', 'a12', 'a13'], + 'a[a:b][2:3]': ['aa2', 'aa3', 'ab2', 'ab3'], + } + + def setUp(self): + fake_loader = DictDataLoader({}) + + self.i = InventoryManager(loader=fake_loader, sources=[None]) + + def test_split_patterns(self): + + for p in self.patterns: + r = self.patterns[p] + self.assertEqual(r, split_host_pattern(p)) + + for p, r in self.pattern_lists: + self.assertEqual(r, split_host_pattern(p)) + + def test_ranges(self): + + for s in self.subscripts: + r = self.subscripts[s] + self.assertEqual(r[0], self.i._split_subscript(s)) + self.assertEqual( + r[1], + self.i._apply_subscript( + list(string.ascii_letters), + r[0][1] + ) + ) + + +class TestInventoryPlugins(unittest.TestCase): + + def test_empty_inventory(self): + inventory = self._get_inventory('') + + self.assertIn('all', inventory.groups) + self.assertIn('ungrouped', inventory.groups) + self.assertFalse(inventory.groups['all'].get_hosts()) + self.assertFalse(inventory.groups['ungrouped'].get_hosts()) + + def test_ini(self): + self._test_default_groups(""" + host1 + host2 + host3 + [servers] + host3 + host4 + host5 + """) + + def test_ini_explicit_ungrouped(self): + self._test_default_groups(""" + [ungrouped] + host1 + host2 + host3 + [servers] + host3 + host4 + host5 + """) + + def test_ini_variables_stringify(self): + values = ['string', 'no', 'No', 'false', 'FALSE', [], False, 0] + + inventory_content = "host1 " + inventory_content += ' '.join(['var%s=%s' % (i, to_text(x)) for i, x in enumerate(values)]) + inventory = self._get_inventory(inventory_content) + + variables = inventory.get_host('host1').vars + for i in range(len(values)): + if isinstance(values[i], string_types): + self.assertIsInstance(variables['var%s' % i], string_types) + else: + self.assertIsInstance(variables['var%s' % i], type(values[i])) + + @mock.patch('ansible.inventory.manager.unfrackpath', mock_unfrackpath_noop) + @mock.patch('os.path.exists', lambda x: True) + @mock.patch('os.access', lambda x, y: True) + def test_yaml_inventory(self, filename="test.yaml"): + inventory_content = {filename: textwrap.dedent("""\ + --- + all: + hosts: + test1: + test2: + """)} + C.INVENTORY_ENABLED = ['yaml'] + fake_loader = DictDataLoader(inventory_content) + im = InventoryManager(loader=fake_loader, sources=filename) + self.assertTrue(im._inventory.hosts) + self.assertIn('test1', im._inventory.hosts) + self.assertIn('test2', im._inventory.hosts) + self.assertIn(im._inventory.get_host('test1'), im._inventory.groups['all'].hosts) + self.assertIn(im._inventory.get_host('test2'), im._inventory.groups['all'].hosts) + self.assertEqual(len(im._inventory.groups['all'].hosts), 2) + self.assertIn(im._inventory.get_host('test1'), im._inventory.groups['ungrouped'].hosts) + self.assertIn(im._inventory.get_host('test2'), im._inventory.groups['ungrouped'].hosts) + self.assertEqual(len(im._inventory.groups['ungrouped'].hosts), 2) + + def _get_inventory(self, inventory_content): + + fake_loader = DictDataLoader({__file__: inventory_content}) + + return InventoryManager(loader=fake_loader, sources=[__file__]) + + def _test_default_groups(self, inventory_content): + inventory = self._get_inventory(inventory_content) + + self.assertIn('all', inventory.groups) + self.assertIn('ungrouped', inventory.groups) + all_hosts = set(host.name for host in inventory.groups['all'].get_hosts()) + self.assertEqual(set(['host1', 'host2', 'host3', 'host4', 'host5']), all_hosts) + ungrouped_hosts = set(host.name for host in inventory.groups['ungrouped'].get_hosts()) + self.assertEqual(set(['host1', 'host2']), ungrouped_hosts) + servers_hosts = set(host.name for host in inventory.groups['servers'].get_hosts()) + self.assertEqual(set(['host3', 'host4', 'host5']), servers_hosts) diff --git a/test/units/plugins/inventory/test_script.py b/test/units/plugins/inventory/test_script.py new file mode 100644 index 00000000..5f054813 --- /dev/null +++ b/test/units/plugins/inventory/test_script.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- + +# Copyright 2017 Chris Meyers <cmeyers@ansible.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import pytest + +from ansible import constants as C +from ansible.errors import AnsibleError +from ansible.plugins.loader import PluginLoader +from units.compat import mock +from units.compat import unittest +from ansible.module_utils._text import to_bytes, to_native + + +class TestInventoryModule(unittest.TestCase): + + def setUp(self): + + class Inventory(): + cache = dict() + + class PopenResult(): + returncode = 0 + stdout = b"" + stderr = b"" + + def communicate(self): + return (self.stdout, self.stderr) + + self.popen_result = PopenResult() + self.inventory = Inventory() + self.loader = mock.MagicMock() + self.loader.load = mock.MagicMock() + + inv_loader = PluginLoader('InventoryModule', 'ansible.plugins.inventory', C.DEFAULT_INVENTORY_PLUGIN_PATH, 'inventory_plugins') + self.inventory_module = inv_loader.get('script') + self.inventory_module.set_options() + + def register_patch(name): + patcher = mock.patch(name) + self.addCleanup(patcher.stop) + return patcher.start() + + self.popen = register_patch('subprocess.Popen') + self.popen.return_value = self.popen_result + + self.BaseInventoryPlugin = register_patch('ansible.plugins.inventory.BaseInventoryPlugin') + self.BaseInventoryPlugin.get_cache_prefix.return_value = 'abc123' + + def test_parse_subprocess_path_not_found_fail(self): + self.popen.side_effect = OSError("dummy text") + + with pytest.raises(AnsibleError) as e: + self.inventory_module.parse(self.inventory, self.loader, '/foo/bar/foobar.py') + assert e.value.message == "problem running /foo/bar/foobar.py --list (dummy text)" + + def test_parse_subprocess_err_code_fail(self): + self.popen_result.stdout = to_bytes(u"fooébar", errors='surrogate_escape') + self.popen_result.stderr = to_bytes(u"dummyédata") + + self.popen_result.returncode = 1 + + with pytest.raises(AnsibleError) as e: + self.inventory_module.parse(self.inventory, self.loader, '/foo/bar/foobar.py') + assert e.value.message == to_native("Inventory script (/foo/bar/foobar.py) had an execution error: " + "dummyédata\n ") + + def test_parse_utf8_fail(self): + self.popen_result.returncode = 0 + self.popen_result.stderr = to_bytes("dummyédata") + self.loader.load.side_effect = TypeError('obj must be string') + + with pytest.raises(AnsibleError) as e: + self.inventory_module.parse(self.inventory, self.loader, '/foo/bar/foobar.py') + assert e.value.message == to_native("failed to parse executable inventory script results from " + "/foo/bar/foobar.py: obj must be string\ndummyédata\n") + + def test_parse_dict_fail(self): + self.popen_result.returncode = 0 + self.popen_result.stderr = to_bytes("dummyédata") + self.loader.load.return_value = 'i am not a dict' + + with pytest.raises(AnsibleError) as e: + self.inventory_module.parse(self.inventory, self.loader, '/foo/bar/foobar.py') + assert e.value.message == to_native("failed to parse executable inventory script results from " + "/foo/bar/foobar.py: needs to be a json dict\ndummyédata\n") diff --git a/test/units/plugins/loader_fixtures/__init__.py b/test/units/plugins/loader_fixtures/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/units/plugins/loader_fixtures/__init__.py diff --git a/test/units/plugins/loader_fixtures/import_fixture.py b/test/units/plugins/loader_fixtures/import_fixture.py new file mode 100644 index 00000000..81127332 --- /dev/null +++ b/test/units/plugins/loader_fixtures/import_fixture.py @@ -0,0 +1,9 @@ +# Nothing to see here, this file is just empty to support a imp.load_source +# without doing anything +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + + +class test: + def __init__(self, *args, **kwargs): + pass diff --git a/test/units/plugins/lookup/__init__.py b/test/units/plugins/lookup/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/units/plugins/lookup/__init__.py diff --git a/test/units/plugins/lookup/test_env.py b/test/units/plugins/lookup/test_env.py new file mode 100644 index 00000000..5d9713fe --- /dev/null +++ b/test/units/plugins/lookup/test_env.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Copyright: (c) 2019, Abhay Kadam <abhaykadam88@gmail.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import pytest + +from ansible.plugins.loader import lookup_loader + + +@pytest.mark.parametrize('env_var,exp_value', [ + ('foo', 'bar'), + ('equation', 'a=b*100') +]) +def test_env_var_value(monkeypatch, env_var, exp_value): + monkeypatch.setattr('ansible.utils.py3compat.environ.get', lambda x, y: exp_value) + + env_lookup = lookup_loader.get('env') + retval = env_lookup.run([env_var], None) + assert retval == [exp_value] + + +@pytest.mark.parametrize('env_var,exp_value', [ + ('simple_var', 'alpha-β-gamma'), + ('the_var', 'ãnˈsiβle') +]) +def test_utf8_env_var_value(monkeypatch, env_var, exp_value): + monkeypatch.setattr('ansible.utils.py3compat.environ.get', lambda x, y: exp_value) + + env_lookup = lookup_loader.get('env') + retval = env_lookup.run([env_var], None) + assert retval == [exp_value] diff --git a/test/units/plugins/lookup/test_ini.py b/test/units/plugins/lookup/test_ini.py new file mode 100644 index 00000000..adf2bac2 --- /dev/null +++ b/test/units/plugins/lookup/test_ini.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +# (c) 2015, Toshio Kuratomi <tkuratomi@ansible.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from units.compat import unittest +from ansible.plugins.lookup.ini import _parse_params + + +class TestINILookup(unittest.TestCase): + + # Currently there isn't a new-style + old_style_params_data = ( + # Simple case + dict( + term=u'keyA section=sectionA file=/path/to/file', + expected=[u'file=/path/to/file', u'keyA', u'section=sectionA'], + ), + dict( + term=u'keyB section=sectionB with space file=/path/with/embedded spaces and/file', + expected=[u'file=/path/with/embedded spaces and/file', u'keyB', u'section=sectionB with space'], + ), + dict( + term=u'keyC section=sectionC file=/path/with/equals/cn=com.ansible', + expected=[u'file=/path/with/equals/cn=com.ansible', u'keyC', u'section=sectionC'], + ), + dict( + term=u'keyD section=sectionD file=/path/with space and/equals/cn=com.ansible', + expected=[u'file=/path/with space and/equals/cn=com.ansible', u'keyD', u'section=sectionD'], + ), + dict( + term=u'keyE section=sectionE file=/path/with/unicode/くらとみ/file', + expected=[u'file=/path/with/unicode/くらとみ/file', u'keyE', u'section=sectionE'], + ), + dict( + term=u'keyF section=sectionF file=/path/with/utf 8 and spaces/くらとみ/file', + expected=[u'file=/path/with/utf 8 and spaces/くらとみ/file', u'keyF', u'section=sectionF'], + ), + ) + + def test_parse_parameters(self): + for testcase in self.old_style_params_data: + # print(testcase) + params = _parse_params(testcase['term']) + params.sort() + self.assertEqual(params, testcase['expected']) diff --git a/test/units/plugins/lookup/test_password.py b/test/units/plugins/lookup/test_password.py new file mode 100644 index 00000000..9871f4ab --- /dev/null +++ b/test/units/plugins/lookup/test_password.py @@ -0,0 +1,501 @@ +# -*- coding: utf-8 -*- +# (c) 2015, Toshio Kuratomi <tkuratomi@ansible.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +try: + import passlib + from passlib.handlers import pbkdf2 +except ImportError: + passlib = None + pbkdf2 = None + +import pytest + +from units.mock.loader import DictDataLoader + +from units.compat import unittest +from units.compat.mock import mock_open, patch +from ansible.errors import AnsibleError +from ansible.module_utils.six import text_type +from ansible.module_utils.six.moves import builtins +from ansible.module_utils._text import to_bytes +from ansible.plugins.loader import PluginLoader +from ansible.plugins.lookup import password + + +DEFAULT_CHARS = sorted([u'ascii_letters', u'digits', u".,:-_"]) +DEFAULT_CANDIDATE_CHARS = u'.,:-_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' + +# Currently there isn't a new-style +old_style_params_data = ( + # Simple case + dict( + term=u'/path/to/file', + filename=u'/path/to/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=DEFAULT_CHARS), + candidate_chars=DEFAULT_CANDIDATE_CHARS, + ), + + # Special characters in path + dict( + term=u'/path/with/embedded spaces and/file', + filename=u'/path/with/embedded spaces and/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=DEFAULT_CHARS), + candidate_chars=DEFAULT_CANDIDATE_CHARS, + ), + dict( + term=u'/path/with/equals/cn=com.ansible', + filename=u'/path/with/equals/cn=com.ansible', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=DEFAULT_CHARS), + candidate_chars=DEFAULT_CANDIDATE_CHARS, + ), + dict( + term=u'/path/with/unicode/くらとみ/file', + filename=u'/path/with/unicode/くらとみ/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=DEFAULT_CHARS), + candidate_chars=DEFAULT_CANDIDATE_CHARS, + ), + # Mix several special chars + dict( + term=u'/path/with/utf 8 and spaces/くらとみ/file', + filename=u'/path/with/utf 8 and spaces/くらとみ/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=DEFAULT_CHARS), + candidate_chars=DEFAULT_CANDIDATE_CHARS, + ), + dict( + term=u'/path/with/encoding=unicode/くらとみ/file', + filename=u'/path/with/encoding=unicode/くらとみ/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=DEFAULT_CHARS), + candidate_chars=DEFAULT_CANDIDATE_CHARS, + ), + dict( + term=u'/path/with/encoding=unicode/くらとみ/and spaces file', + filename=u'/path/with/encoding=unicode/くらとみ/and spaces file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=DEFAULT_CHARS), + candidate_chars=DEFAULT_CANDIDATE_CHARS, + ), + + # Simple parameters + dict( + term=u'/path/to/file length=42', + filename=u'/path/to/file', + params=dict(length=42, encrypt=None, chars=DEFAULT_CHARS), + candidate_chars=DEFAULT_CANDIDATE_CHARS, + ), + dict( + term=u'/path/to/file encrypt=pbkdf2_sha256', + filename=u'/path/to/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt='pbkdf2_sha256', chars=DEFAULT_CHARS), + candidate_chars=DEFAULT_CANDIDATE_CHARS, + ), + dict( + term=u'/path/to/file chars=abcdefghijklmnop', + filename=u'/path/to/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=[u'abcdefghijklmnop']), + candidate_chars=u'abcdefghijklmnop', + ), + dict( + term=u'/path/to/file chars=digits,abc,def', + filename=u'/path/to/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=sorted([u'digits', u'abc', u'def'])), + candidate_chars=u'abcdef0123456789', + ), + + # Including comma in chars + dict( + term=u'/path/to/file chars=abcdefghijklmnop,,digits', + filename=u'/path/to/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=sorted([u'abcdefghijklmnop', u',', u'digits'])), + candidate_chars=u',abcdefghijklmnop0123456789', + ), + dict( + term=u'/path/to/file chars=,,', + filename=u'/path/to/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=[u',']), + candidate_chars=u',', + ), + + # Including = in chars + dict( + term=u'/path/to/file chars=digits,=,,', + filename=u'/path/to/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=sorted([u'digits', u'=', u','])), + candidate_chars=u',=0123456789', + ), + dict( + term=u'/path/to/file chars=digits,abc=def', + filename=u'/path/to/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=sorted([u'digits', u'abc=def'])), + candidate_chars=u'abc=def0123456789', + ), + + # Including unicode in chars + dict( + term=u'/path/to/file chars=digits,くらとみ,,', + filename=u'/path/to/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=sorted([u'digits', u'くらとみ', u','])), + candidate_chars=u',0123456789くらとみ', + ), + # Including only unicode in chars + dict( + term=u'/path/to/file chars=くらとみ', + filename=u'/path/to/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=sorted([u'くらとみ'])), + candidate_chars=u'くらとみ', + ), + + # Include ':' in path + dict( + term=u'/path/to/file_with:colon chars=ascii_letters,digits', + filename=u'/path/to/file_with:colon', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=sorted([u'ascii_letters', u'digits'])), + candidate_chars=u'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', + ), + + # Including special chars in both path and chars + # Special characters in path + dict( + term=u'/path/with/embedded spaces and/file chars=abc=def', + filename=u'/path/with/embedded spaces and/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=[u'abc=def']), + candidate_chars=u'abc=def', + ), + dict( + term=u'/path/with/equals/cn=com.ansible chars=abc=def', + filename=u'/path/with/equals/cn=com.ansible', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=[u'abc=def']), + candidate_chars=u'abc=def', + ), + dict( + term=u'/path/with/unicode/くらとみ/file chars=くらとみ', + filename=u'/path/with/unicode/くらとみ/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=[u'くらとみ']), + candidate_chars=u'くらとみ', + ), +) + + +class TestParseParameters(unittest.TestCase): + def test(self): + for testcase in old_style_params_data: + filename, params = password._parse_parameters(testcase['term']) + params['chars'].sort() + self.assertEqual(filename, testcase['filename']) + self.assertEqual(params, testcase['params']) + + def test_unrecognized_value(self): + testcase = dict(term=u'/path/to/file chars=くらとみi sdfsdf', + filename=u'/path/to/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=[u'くらとみ']), + candidate_chars=u'くらとみ') + self.assertRaises(AnsibleError, password._parse_parameters, testcase['term']) + + def test_invalid_params(self): + testcase = dict(term=u'/path/to/file chars=くらとみi somethign_invalid=123', + filename=u'/path/to/file', + params=dict(length=password.DEFAULT_LENGTH, encrypt=None, chars=[u'くらとみ']), + candidate_chars=u'くらとみ') + self.assertRaises(AnsibleError, password._parse_parameters, testcase['term']) + + +class TestReadPasswordFile(unittest.TestCase): + def setUp(self): + self.os_path_exists = password.os.path.exists + + def tearDown(self): + password.os.path.exists = self.os_path_exists + + def test_no_password_file(self): + password.os.path.exists = lambda x: False + self.assertEqual(password._read_password_file(b'/nonexistent'), None) + + def test_with_password_file(self): + password.os.path.exists = lambda x: True + with patch.object(builtins, 'open', mock_open(read_data=b'Testing\n')) as m: + self.assertEqual(password._read_password_file(b'/etc/motd'), u'Testing') + + +class TestGenCandidateChars(unittest.TestCase): + def _assert_gen_candidate_chars(self, testcase): + expected_candidate_chars = testcase['candidate_chars'] + params = testcase['params'] + chars_spec = params['chars'] + res = password._gen_candidate_chars(chars_spec) + self.assertEqual(res, expected_candidate_chars) + + def test_gen_candidate_chars(self): + for testcase in old_style_params_data: + self._assert_gen_candidate_chars(testcase) + + +class TestRandomPassword(unittest.TestCase): + def _assert_valid_chars(self, res, chars): + for res_char in res: + self.assertIn(res_char, chars) + + def test_default(self): + res = password.random_password() + self.assertEqual(len(res), password.DEFAULT_LENGTH) + self.assertTrue(isinstance(res, text_type)) + self._assert_valid_chars(res, DEFAULT_CANDIDATE_CHARS) + + def test_zero_length(self): + res = password.random_password(length=0) + self.assertEqual(len(res), 0) + self.assertTrue(isinstance(res, text_type)) + self._assert_valid_chars(res, u',') + + def test_just_a_common(self): + res = password.random_password(length=1, chars=u',') + self.assertEqual(len(res), 1) + self.assertEqual(res, u',') + + def test_free_will(self): + # A Rush and Spinal Tap reference twofer + res = password.random_password(length=11, chars=u'a') + self.assertEqual(len(res), 11) + self.assertEqual(res, 'aaaaaaaaaaa') + self._assert_valid_chars(res, u'a') + + def test_unicode(self): + res = password.random_password(length=11, chars=u'くらとみ') + self._assert_valid_chars(res, u'くらとみ') + self.assertEqual(len(res), 11) + + def test_gen_password(self): + for testcase in old_style_params_data: + params = testcase['params'] + candidate_chars = testcase['candidate_chars'] + params_chars_spec = password._gen_candidate_chars(params['chars']) + password_string = password.random_password(length=params['length'], + chars=params_chars_spec) + self.assertEqual(len(password_string), + params['length'], + msg='generated password=%s has length (%s) instead of expected length (%s)' % + (password_string, len(password_string), params['length'])) + + for char in password_string: + self.assertIn(char, candidate_chars, + msg='%s not found in %s from chars spect %s' % + (char, candidate_chars, params['chars'])) + + +class TestParseContent(unittest.TestCase): + def test_empty_password_file(self): + plaintext_password, salt = password._parse_content(u'') + self.assertEqual(plaintext_password, u'') + self.assertEqual(salt, None) + + def test(self): + expected_content = u'12345678' + file_content = expected_content + plaintext_password, salt = password._parse_content(file_content) + self.assertEqual(plaintext_password, expected_content) + self.assertEqual(salt, None) + + def test_with_salt(self): + expected_content = u'12345678 salt=87654321' + file_content = expected_content + plaintext_password, salt = password._parse_content(file_content) + self.assertEqual(plaintext_password, u'12345678') + self.assertEqual(salt, u'87654321') + + +class TestFormatContent(unittest.TestCase): + def test_no_encrypt(self): + self.assertEqual( + password._format_content(password=u'hunter42', + salt=u'87654321', + encrypt=False), + u'hunter42 salt=87654321') + + def test_no_encrypt_no_salt(self): + self.assertEqual( + password._format_content(password=u'hunter42', + salt=None, + encrypt=None), + u'hunter42') + + def test_encrypt(self): + self.assertEqual( + password._format_content(password=u'hunter42', + salt=u'87654321', + encrypt='pbkdf2_sha256'), + u'hunter42 salt=87654321') + + def test_encrypt_no_salt(self): + self.assertRaises(AssertionError, password._format_content, u'hunter42', None, 'pbkdf2_sha256') + + +class TestWritePasswordFile(unittest.TestCase): + def setUp(self): + self.makedirs_safe = password.makedirs_safe + self.os_chmod = password.os.chmod + password.makedirs_safe = lambda path, mode: None + password.os.chmod = lambda path, mode: None + + def tearDown(self): + password.makedirs_safe = self.makedirs_safe + password.os.chmod = self.os_chmod + + def test_content_written(self): + + with patch.object(builtins, 'open', mock_open()) as m: + password._write_password_file(b'/this/is/a/test/caf\xc3\xa9', u'Testing Café') + + m.assert_called_once_with(b'/this/is/a/test/caf\xc3\xa9', 'wb') + m().write.assert_called_once_with(u'Testing Café\n'.encode('utf-8')) + + +class BaseTestLookupModule(unittest.TestCase): + def setUp(self): + self.fake_loader = DictDataLoader({'/path/to/somewhere': 'sdfsdf'}) + self.password_lookup = password.LookupModule(loader=self.fake_loader) + self.os_path_exists = password.os.path.exists + self.os_open = password.os.open + password.os.open = lambda path, flag: None + self.os_close = password.os.close + password.os.close = lambda fd: None + self.os_remove = password.os.remove + password.os.remove = lambda path: None + self.makedirs_safe = password.makedirs_safe + password.makedirs_safe = lambda path, mode: None + + def tearDown(self): + password.os.path.exists = self.os_path_exists + password.os.open = self.os_open + password.os.close = self.os_close + password.os.remove = self.os_remove + password.makedirs_safe = self.makedirs_safe + + +class TestLookupModuleWithoutPasslib(BaseTestLookupModule): + @patch.object(PluginLoader, '_get_paths') + @patch('ansible.plugins.lookup.password._write_password_file') + def test_no_encrypt(self, mock_get_paths, mock_write_file): + mock_get_paths.return_value = ['/path/one', '/path/two', '/path/three'] + + results = self.password_lookup.run([u'/path/to/somewhere'], None) + + # FIXME: assert something useful + for result in results: + assert len(result) == password.DEFAULT_LENGTH + assert isinstance(result, text_type) + + @patch.object(PluginLoader, '_get_paths') + @patch('ansible.plugins.lookup.password._write_password_file') + def test_password_already_created_no_encrypt(self, mock_get_paths, mock_write_file): + mock_get_paths.return_value = ['/path/one', '/path/two', '/path/three'] + password.os.path.exists = lambda x: x == to_bytes('/path/to/somewhere') + + with patch.object(builtins, 'open', mock_open(read_data=b'hunter42 salt=87654321\n')) as m: + results = self.password_lookup.run([u'/path/to/somewhere chars=anything'], None) + + for result in results: + self.assertEqual(result, u'hunter42') + + @patch.object(PluginLoader, '_get_paths') + @patch('ansible.plugins.lookup.password._write_password_file') + def test_only_a(self, mock_get_paths, mock_write_file): + mock_get_paths.return_value = ['/path/one', '/path/two', '/path/three'] + + results = self.password_lookup.run([u'/path/to/somewhere chars=a'], None) + for result in results: + self.assertEqual(result, u'a' * password.DEFAULT_LENGTH) + + @patch('time.sleep') + def test_lock_been_held(self, mock_sleep): + # pretend the lock file is here + password.os.path.exists = lambda x: True + try: + with patch.object(builtins, 'open', mock_open(read_data=b'hunter42 salt=87654321\n')) as m: + # should timeout here + results = self.password_lookup.run([u'/path/to/somewhere chars=anything'], None) + self.fail("Lookup didn't timeout when lock already been held") + except AnsibleError: + pass + + def test_lock_not_been_held(self): + # pretend now there is password file but no lock + password.os.path.exists = lambda x: x == to_bytes('/path/to/somewhere') + try: + with patch.object(builtins, 'open', mock_open(read_data=b'hunter42 salt=87654321\n')) as m: + # should not timeout here + results = self.password_lookup.run([u'/path/to/somewhere chars=anything'], None) + except AnsibleError: + self.fail('Lookup timeouts when lock is free') + + for result in results: + self.assertEqual(result, u'hunter42') + + +@pytest.mark.skipif(passlib is None, reason='passlib must be installed to run these tests') +class TestLookupModuleWithPasslib(BaseTestLookupModule): + def setUp(self): + super(TestLookupModuleWithPasslib, self).setUp() + + # Different releases of passlib default to a different number of rounds + self.sha256 = passlib.registry.get_crypt_handler('pbkdf2_sha256') + sha256_for_tests = pbkdf2.create_pbkdf2_hash("sha256", 32, 20000) + passlib.registry.register_crypt_handler(sha256_for_tests, force=True) + + def tearDown(self): + super(TestLookupModuleWithPasslib, self).tearDown() + + passlib.registry.register_crypt_handler(self.sha256, force=True) + + @patch.object(PluginLoader, '_get_paths') + @patch('ansible.plugins.lookup.password._write_password_file') + def test_encrypt(self, mock_get_paths, mock_write_file): + mock_get_paths.return_value = ['/path/one', '/path/two', '/path/three'] + + results = self.password_lookup.run([u'/path/to/somewhere encrypt=pbkdf2_sha256'], None) + + # pbkdf2 format plus hash + expected_password_length = 76 + + for result in results: + self.assertEqual(len(result), expected_password_length) + # result should have 5 parts split by '$' + str_parts = result.split('$', 5) + + # verify the result is parseable by the passlib + crypt_parts = passlib.hash.pbkdf2_sha256.parsehash(result) + + # verify it used the right algo type + self.assertEqual(str_parts[1], 'pbkdf2-sha256') + + self.assertEqual(len(str_parts), 5) + + # verify the string and parsehash agree on the number of rounds + self.assertEqual(int(str_parts[2]), crypt_parts['rounds']) + self.assertIsInstance(result, text_type) + + @patch.object(PluginLoader, '_get_paths') + @patch('ansible.plugins.lookup.password._write_password_file') + def test_password_already_created_encrypt(self, mock_get_paths, mock_write_file): + mock_get_paths.return_value = ['/path/one', '/path/two', '/path/three'] + password.os.path.exists = lambda x: x == to_bytes('/path/to/somewhere') + + with patch.object(builtins, 'open', mock_open(read_data=b'hunter42 salt=87654321\n')) as m: + results = self.password_lookup.run([u'/path/to/somewhere chars=anything encrypt=pbkdf2_sha256'], None) + for result in results: + self.assertEqual(result, u'$pbkdf2-sha256$20000$ODc2NTQzMjE$Uikde0cv0BKaRaAXMrUQB.zvG4GmnjClwjghwIRf2gU') diff --git a/test/units/plugins/shell/__init__.py b/test/units/plugins/shell/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/units/plugins/shell/__init__.py diff --git a/test/units/plugins/shell/test_cmd.py b/test/units/plugins/shell/test_cmd.py new file mode 100644 index 00000000..4c1a654b --- /dev/null +++ b/test/units/plugins/shell/test_cmd.py @@ -0,0 +1,19 @@ +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import pytest + +from ansible.plugins.shell.cmd import ShellModule + + +@pytest.mark.parametrize('s, expected', [ + ['arg1', 'arg1'], + [None, '""'], + ['arg1 and 2', '^"arg1 and 2^"'], + ['malicious argument\\"&whoami', '^"malicious argument\\\\^"^&whoami^"'], + ['C:\\temp\\some ^%file% > nul', '^"C:\\temp\\some ^^^%file^% ^> nul^"'] +]) +def test_quote_args(s, expected): + cmd = ShellModule() + actual = cmd.quote(s) + assert actual == expected diff --git a/test/units/plugins/shell/test_powershell.py b/test/units/plugins/shell/test_powershell.py new file mode 100644 index 00000000..c94baabb --- /dev/null +++ b/test/units/plugins/shell/test_powershell.py @@ -0,0 +1,83 @@ +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.plugins.shell.powershell import _parse_clixml, ShellModule + + +def test_parse_clixml_empty(): + empty = b'#< CLIXML\r\n<Objs Version="1.1.0.1" xmlns="http://schemas.microsoft.com/powershell/2004/04"></Objs>' + expected = b'' + actual = _parse_clixml(empty) + assert actual == expected + + +def test_parse_clixml_with_progress(): + progress = b'#< CLIXML\r\n<Objs Version="1.1.0.1" xmlns="http://schemas.microsoft.com/powershell/2004/04">' \ + b'<Obj S="progress" RefId="0"><TN RefId="0"><T>System.Management.Automation.PSCustomObject</T><T>System.Object</T></TN><MS>' \ + b'<I64 N="SourceId">1</I64><PR N="Record"><AV>Preparing modules for first use.</AV><AI>0</AI><Nil />' \ + b'<PI>-1</PI><PC>-1</PC><T>Completed</T><SR>-1</SR><SD> </SD></PR></MS></Obj></Objs>' + expected = b'' + actual = _parse_clixml(progress) + assert actual == expected + + +def test_parse_clixml_single_stream(): + single_stream = b'#< CLIXML\r\n<Objs Version="1.1.0.1" xmlns="http://schemas.microsoft.com/powershell/2004/04">' \ + b'<S S="Error">fake : The term \'fake\' is not recognized as the name of a cmdlet. Check _x000D__x000A_</S>' \ + b'<S S="Error">the spelling of the name, or if a path was included._x000D__x000A_</S>' \ + b'<S S="Error">At line:1 char:1_x000D__x000A_</S>' \ + b'<S S="Error">+ fake cmdlet_x000D__x000A_</S><S S="Error">+ ~~~~_x000D__x000A_</S>' \ + b'<S S="Error"> + CategoryInfo : ObjectNotFound: (fake:String) [], CommandNotFoundException_x000D__x000A_</S>' \ + b'<S S="Error"> + FullyQualifiedErrorId : CommandNotFoundException_x000D__x000A_</S><S S="Error"> _x000D__x000A_</S>' \ + b'</Objs>' + expected = b"fake : The term 'fake' is not recognized as the name of a cmdlet. Check \r\n" \ + b"the spelling of the name, or if a path was included.\r\n" \ + b"At line:1 char:1\r\n" \ + b"+ fake cmdlet\r\n" \ + b"+ ~~~~\r\n" \ + b" + CategoryInfo : ObjectNotFound: (fake:String) [], CommandNotFoundException\r\n" \ + b" + FullyQualifiedErrorId : CommandNotFoundException\r\n " + actual = _parse_clixml(single_stream) + assert actual == expected + + +def test_parse_clixml_multiple_streams(): + multiple_stream = b'#< CLIXML\r\n<Objs Version="1.1.0.1" xmlns="http://schemas.microsoft.com/powershell/2004/04">' \ + b'<S S="Error">fake : The term \'fake\' is not recognized as the name of a cmdlet. Check _x000D__x000A_</S>' \ + b'<S S="Error">the spelling of the name, or if a path was included._x000D__x000A_</S>' \ + b'<S S="Error">At line:1 char:1_x000D__x000A_</S>' \ + b'<S S="Error">+ fake cmdlet_x000D__x000A_</S><S S="Error">+ ~~~~_x000D__x000A_</S>' \ + b'<S S="Error"> + CategoryInfo : ObjectNotFound: (fake:String) [], CommandNotFoundException_x000D__x000A_</S>' \ + b'<S S="Error"> + FullyQualifiedErrorId : CommandNotFoundException_x000D__x000A_</S><S S="Error"> _x000D__x000A_</S>' \ + b'<S S="Info">hi info</S>' \ + b'</Objs>' + expected = b"hi info" + actual = _parse_clixml(multiple_stream, stream="Info") + assert actual == expected + + +def test_parse_clixml_multiple_elements(): + multiple_elements = b'#< CLIXML\r\n#< CLIXML\r\n<Objs Version="1.1.0.1" xmlns="http://schemas.microsoft.com/powershell/2004/04">' \ + b'<Obj S="progress" RefId="0"><TN RefId="0"><T>System.Management.Automation.PSCustomObject</T><T>System.Object</T></TN><MS>' \ + b'<I64 N="SourceId">1</I64><PR N="Record"><AV>Preparing modules for first use.</AV><AI>0</AI><Nil />' \ + b'<PI>-1</PI><PC>-1</PC><T>Completed</T><SR>-1</SR><SD> </SD></PR></MS></Obj>' \ + b'<S S="Error">Error 1</S></Objs>' \ + b'<Objs Version="1.1.0.1" xmlns="http://schemas.microsoft.com/powershell/2004/04"><Obj S="progress" RefId="0">' \ + b'<TN RefId="0"><T>System.Management.Automation.PSCustomObject</T><T>System.Object</T></TN><MS>' \ + b'<I64 N="SourceId">1</I64><PR N="Record"><AV>Preparing modules for first use.</AV><AI>0</AI><Nil />' \ + b'<PI>-1</PI><PC>-1</PC><T>Completed</T><SR>-1</SR><SD> </SD></PR></MS></Obj>' \ + b'<Obj S="progress" RefId="1"><TNRef RefId="0" /><MS><I64 N="SourceId">2</I64>' \ + b'<PR N="Record"><AV>Preparing modules for first use.</AV><AI>0</AI><Nil />' \ + b'<PI>-1</PI><PC>-1</PC><T>Completed</T><SR>-1</SR><SD> </SD></PR></MS></Obj>' \ + b'<S S="Error">Error 2</S></Objs>' + expected = b"Error 1\r\nError 2" + actual = _parse_clixml(multiple_elements) + assert actual == expected + + +def test_join_path_unc(): + pwsh = ShellModule() + unc_path_parts = ['\\\\host\\share\\dir1\\\\dir2\\', '\\dir3/dir4', 'dir5', 'dir6\\'] + expected = '\\\\host\\share\\dir1\\dir2\\dir3\\dir4\\dir5\\dir6' + actual = pwsh.join_path(*unc_path_parts) + assert actual == expected diff --git a/test/units/plugins/strategy/__init__.py b/test/units/plugins/strategy/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/units/plugins/strategy/__init__.py diff --git a/test/units/plugins/strategy/test_linear.py b/test/units/plugins/strategy/test_linear.py new file mode 100644 index 00000000..74887030 --- /dev/null +++ b/test/units/plugins/strategy/test_linear.py @@ -0,0 +1,177 @@ +# Copyright (c) 2018 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + + +from units.compat import unittest +from units.compat.mock import patch, MagicMock + +from ansible.executor.play_iterator import PlayIterator +from ansible.playbook import Playbook +from ansible.playbook.play_context import PlayContext +from ansible.plugins.strategy.linear import StrategyModule +from ansible.executor.task_queue_manager import TaskQueueManager + +from units.mock.loader import DictDataLoader +from units.mock.path import mock_unfrackpath_noop + + +class TestStrategyLinear(unittest.TestCase): + + @patch('ansible.playbook.role.definition.unfrackpath', mock_unfrackpath_noop) + def test_noop(self): + fake_loader = DictDataLoader({ + "test_play.yml": """ + - hosts: all + gather_facts: no + tasks: + - block: + - block: + - name: task1 + debug: msg='task1' + failed_when: inventory_hostname == 'host01' + + - name: task2 + debug: msg='task2' + + rescue: + - name: rescue1 + debug: msg='rescue1' + + - name: rescue2 + debug: msg='rescue2' + """, + }) + + mock_var_manager = MagicMock() + mock_var_manager._fact_cache = dict() + mock_var_manager.get_vars.return_value = dict() + + p = Playbook.load('test_play.yml', loader=fake_loader, variable_manager=mock_var_manager) + + inventory = MagicMock() + inventory.hosts = {} + hosts = [] + for i in range(0, 2): + host = MagicMock() + host.name = host.get_name.return_value = 'host%02d' % i + hosts.append(host) + inventory.hosts[host.name] = host + inventory.get_hosts.return_value = hosts + inventory.filter_hosts.return_value = hosts + + mock_var_manager._fact_cache['host00'] = dict() + + play_context = PlayContext(play=p._entries[0]) + + itr = PlayIterator( + inventory=inventory, + play=p._entries[0], + play_context=play_context, + variable_manager=mock_var_manager, + all_vars=dict(), + ) + + tqm = TaskQueueManager( + inventory=inventory, + variable_manager=mock_var_manager, + loader=fake_loader, + passwords=None, + forks=5, + ) + tqm._initialize_processes(3) + strategy = StrategyModule(tqm) + strategy._hosts_cache = [h.name for h in hosts] + strategy._hosts_cache_all = [h.name for h in hosts] + + # implicit meta: flush_handlers + hosts_left = strategy.get_hosts_left(itr) + hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr) + host1_task = hosts_tasks[0][1] + host2_task = hosts_tasks[1][1] + self.assertIsNotNone(host1_task) + self.assertIsNotNone(host2_task) + self.assertEqual(host1_task.action, 'meta') + self.assertEqual(host2_task.action, 'meta') + + # debug: task1, debug: task1 + hosts_left = strategy.get_hosts_left(itr) + hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr) + host1_task = hosts_tasks[0][1] + host2_task = hosts_tasks[1][1] + self.assertIsNotNone(host1_task) + self.assertIsNotNone(host2_task) + self.assertEqual(host1_task.action, 'debug') + self.assertEqual(host2_task.action, 'debug') + self.assertEqual(host1_task.name, 'task1') + self.assertEqual(host2_task.name, 'task1') + + # mark the second host failed + itr.mark_host_failed(hosts[1]) + + # debug: task2, meta: noop + hosts_left = strategy.get_hosts_left(itr) + hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr) + host1_task = hosts_tasks[0][1] + host2_task = hosts_tasks[1][1] + self.assertIsNotNone(host1_task) + self.assertIsNotNone(host2_task) + self.assertEqual(host1_task.action, 'debug') + self.assertEqual(host2_task.action, 'meta') + self.assertEqual(host1_task.name, 'task2') + self.assertEqual(host2_task.name, '') + + # meta: noop, debug: rescue1 + hosts_left = strategy.get_hosts_left(itr) + hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr) + host1_task = hosts_tasks[0][1] + host2_task = hosts_tasks[1][1] + self.assertIsNotNone(host1_task) + self.assertIsNotNone(host2_task) + self.assertEqual(host1_task.action, 'meta') + self.assertEqual(host2_task.action, 'debug') + self.assertEqual(host1_task.name, '') + self.assertEqual(host2_task.name, 'rescue1') + + # meta: noop, debug: rescue2 + hosts_left = strategy.get_hosts_left(itr) + hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr) + host1_task = hosts_tasks[0][1] + host2_task = hosts_tasks[1][1] + self.assertIsNotNone(host1_task) + self.assertIsNotNone(host2_task) + self.assertEqual(host1_task.action, 'meta') + self.assertEqual(host2_task.action, 'debug') + self.assertEqual(host1_task.name, '') + self.assertEqual(host2_task.name, 'rescue2') + + # implicit meta: flush_handlers + hosts_left = strategy.get_hosts_left(itr) + hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr) + host1_task = hosts_tasks[0][1] + host2_task = hosts_tasks[1][1] + self.assertIsNotNone(host1_task) + self.assertIsNotNone(host2_task) + self.assertEqual(host1_task.action, 'meta') + self.assertEqual(host2_task.action, 'meta') + + # implicit meta: flush_handlers + hosts_left = strategy.get_hosts_left(itr) + hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr) + host1_task = hosts_tasks[0][1] + host2_task = hosts_tasks[1][1] + self.assertIsNotNone(host1_task) + self.assertIsNotNone(host2_task) + self.assertEqual(host1_task.action, 'meta') + self.assertEqual(host2_task.action, 'meta') + + # end of iteration + hosts_left = strategy.get_hosts_left(itr) + hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr) + host1_task = hosts_tasks[0][1] + host2_task = hosts_tasks[1][1] + self.assertIsNone(host1_task) + self.assertIsNone(host2_task) diff --git a/test/units/plugins/strategy/test_strategy.py b/test/units/plugins/strategy/test_strategy.py new file mode 100644 index 00000000..9a2574d2 --- /dev/null +++ b/test/units/plugins/strategy/test_strategy.py @@ -0,0 +1,546 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from units.mock.loader import DictDataLoader +from copy import deepcopy +import uuid + +from units.compat import unittest +from units.compat.mock import patch, MagicMock +from ansible.executor.process.worker import WorkerProcess +from ansible.executor.task_queue_manager import TaskQueueManager +from ansible.executor.task_result import TaskResult +from ansible.inventory.host import Host +from ansible.module_utils.six.moves import queue as Queue +from ansible.playbook.handler import Handler +from ansible.plugins.strategy import StrategyBase + + +class TestStrategyBase(unittest.TestCase): + + def test_strategy_base_init(self): + queue_items = [] + + def _queue_empty(*args, **kwargs): + return len(queue_items) == 0 + + def _queue_get(*args, **kwargs): + if len(queue_items) == 0: + raise Queue.Empty + else: + return queue_items.pop() + + def _queue_put(item, *args, **kwargs): + queue_items.append(item) + + mock_queue = MagicMock() + mock_queue.empty.side_effect = _queue_empty + mock_queue.get.side_effect = _queue_get + mock_queue.put.side_effect = _queue_put + + mock_tqm = MagicMock(TaskQueueManager) + mock_tqm._final_q = mock_queue + mock_tqm._workers = [] + strategy_base = StrategyBase(tqm=mock_tqm) + strategy_base.cleanup() + + def test_strategy_base_run(self): + queue_items = [] + + def _queue_empty(*args, **kwargs): + return len(queue_items) == 0 + + def _queue_get(*args, **kwargs): + if len(queue_items) == 0: + raise Queue.Empty + else: + return queue_items.pop() + + def _queue_put(item, *args, **kwargs): + queue_items.append(item) + + mock_queue = MagicMock() + mock_queue.empty.side_effect = _queue_empty + mock_queue.get.side_effect = _queue_get + mock_queue.put.side_effect = _queue_put + + mock_tqm = MagicMock(TaskQueueManager) + mock_tqm._final_q = mock_queue + mock_tqm._stats = MagicMock() + mock_tqm.send_callback.return_value = None + + for attr in ('RUN_OK', 'RUN_ERROR', 'RUN_FAILED_HOSTS', 'RUN_UNREACHABLE_HOSTS'): + setattr(mock_tqm, attr, getattr(TaskQueueManager, attr)) + + mock_iterator = MagicMock() + mock_iterator._play = MagicMock() + mock_iterator._play.handlers = [] + + mock_play_context = MagicMock() + + mock_tqm._failed_hosts = dict() + mock_tqm._unreachable_hosts = dict() + mock_tqm._workers = [] + strategy_base = StrategyBase(tqm=mock_tqm) + + mock_host = MagicMock() + mock_host.name = 'host1' + + self.assertEqual(strategy_base.run(iterator=mock_iterator, play_context=mock_play_context), mock_tqm.RUN_OK) + self.assertEqual(strategy_base.run(iterator=mock_iterator, play_context=mock_play_context, result=TaskQueueManager.RUN_ERROR), mock_tqm.RUN_ERROR) + mock_tqm._failed_hosts = dict(host1=True) + mock_iterator.get_failed_hosts.return_value = [mock_host] + self.assertEqual(strategy_base.run(iterator=mock_iterator, play_context=mock_play_context, result=False), mock_tqm.RUN_FAILED_HOSTS) + mock_tqm._unreachable_hosts = dict(host1=True) + mock_iterator.get_failed_hosts.return_value = [] + self.assertEqual(strategy_base.run(iterator=mock_iterator, play_context=mock_play_context, result=False), mock_tqm.RUN_UNREACHABLE_HOSTS) + strategy_base.cleanup() + + def test_strategy_base_get_hosts(self): + queue_items = [] + + def _queue_empty(*args, **kwargs): + return len(queue_items) == 0 + + def _queue_get(*args, **kwargs): + if len(queue_items) == 0: + raise Queue.Empty + else: + return queue_items.pop() + + def _queue_put(item, *args, **kwargs): + queue_items.append(item) + + mock_queue = MagicMock() + mock_queue.empty.side_effect = _queue_empty + mock_queue.get.side_effect = _queue_get + mock_queue.put.side_effect = _queue_put + + mock_hosts = [] + for i in range(0, 5): + mock_host = MagicMock() + mock_host.name = "host%02d" % (i + 1) + mock_host.has_hostkey = True + mock_hosts.append(mock_host) + + mock_hosts_names = [h.name for h in mock_hosts] + + mock_inventory = MagicMock() + mock_inventory.get_hosts.return_value = mock_hosts + + mock_tqm = MagicMock() + mock_tqm._final_q = mock_queue + mock_tqm.get_inventory.return_value = mock_inventory + + mock_play = MagicMock() + mock_play.hosts = ["host%02d" % (i + 1) for i in range(0, 5)] + + strategy_base = StrategyBase(tqm=mock_tqm) + strategy_base._hosts_cache = strategy_base._hosts_cache_all = mock_hosts_names + + mock_tqm._failed_hosts = [] + mock_tqm._unreachable_hosts = [] + self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), [h.name for h in mock_hosts]) + + mock_tqm._failed_hosts = ["host01"] + self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), [h.name for h in mock_hosts[1:]]) + self.assertEqual(strategy_base.get_failed_hosts(play=mock_play), [mock_hosts[0].name]) + + mock_tqm._unreachable_hosts = ["host02"] + self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), [h.name for h in mock_hosts[2:]]) + strategy_base.cleanup() + + @patch.object(WorkerProcess, 'run') + def test_strategy_base_queue_task(self, mock_worker): + def fake_run(self): + return + + mock_worker.run.side_effect = fake_run + + fake_loader = DictDataLoader() + mock_var_manager = MagicMock() + mock_host = MagicMock() + mock_host.get_vars.return_value = dict() + mock_host.has_hostkey = True + mock_inventory = MagicMock() + mock_inventory.get.return_value = mock_host + + tqm = TaskQueueManager( + inventory=mock_inventory, + variable_manager=mock_var_manager, + loader=fake_loader, + passwords=None, + forks=3, + ) + tqm._initialize_processes(3) + tqm.hostvars = dict() + + mock_task = MagicMock() + mock_task._uuid = 'abcd' + mock_task.throttle = 0 + + try: + strategy_base = StrategyBase(tqm=tqm) + strategy_base._queue_task(host=mock_host, task=mock_task, task_vars=dict(), play_context=MagicMock()) + self.assertEqual(strategy_base._cur_worker, 1) + self.assertEqual(strategy_base._pending_results, 1) + strategy_base._queue_task(host=mock_host, task=mock_task, task_vars=dict(), play_context=MagicMock()) + self.assertEqual(strategy_base._cur_worker, 2) + self.assertEqual(strategy_base._pending_results, 2) + strategy_base._queue_task(host=mock_host, task=mock_task, task_vars=dict(), play_context=MagicMock()) + self.assertEqual(strategy_base._cur_worker, 0) + self.assertEqual(strategy_base._pending_results, 3) + finally: + tqm.cleanup() + + def test_strategy_base_process_pending_results(self): + mock_tqm = MagicMock() + mock_tqm._terminated = False + mock_tqm._failed_hosts = dict() + mock_tqm._unreachable_hosts = dict() + mock_tqm.send_callback.return_value = None + + queue_items = [] + + def _queue_empty(*args, **kwargs): + return len(queue_items) == 0 + + def _queue_get(*args, **kwargs): + if len(queue_items) == 0: + raise Queue.Empty + else: + return queue_items.pop() + + def _queue_put(item, *args, **kwargs): + queue_items.append(item) + + mock_queue = MagicMock() + mock_queue.empty.side_effect = _queue_empty + mock_queue.get.side_effect = _queue_get + mock_queue.put.side_effect = _queue_put + mock_tqm._final_q = mock_queue + + mock_tqm._stats = MagicMock() + mock_tqm._stats.increment.return_value = None + + mock_play = MagicMock() + + mock_host = MagicMock() + mock_host.name = 'test01' + mock_host.vars = dict() + mock_host.get_vars.return_value = dict() + mock_host.has_hostkey = True + + mock_task = MagicMock() + mock_task._role = None + mock_task._parent = None + mock_task.ignore_errors = False + mock_task.ignore_unreachable = False + mock_task._uuid = uuid.uuid4() + mock_task.loop = None + mock_task.copy.return_value = mock_task + + mock_handler_task = Handler() + mock_handler_task.name = 'test handler' + mock_handler_task.action = 'foo' + mock_handler_task._parent = None + mock_handler_task._uuid = 'xxxxxxxxxxxxx' + + mock_iterator = MagicMock() + mock_iterator._play = mock_play + mock_iterator.mark_host_failed.return_value = None + mock_iterator.get_next_task_for_host.return_value = (None, None) + + mock_handler_block = MagicMock() + mock_handler_block.block = [mock_handler_task] + mock_handler_block.rescue = [] + mock_handler_block.always = [] + mock_play.handlers = [mock_handler_block] + + mock_group = MagicMock() + mock_group.add_host.return_value = None + + def _get_host(host_name): + if host_name == 'test01': + return mock_host + return None + + def _get_group(group_name): + if group_name in ('all', 'foo'): + return mock_group + return None + + mock_inventory = MagicMock() + mock_inventory._hosts_cache = dict() + mock_inventory.hosts.return_value = mock_host + mock_inventory.get_host.side_effect = _get_host + mock_inventory.get_group.side_effect = _get_group + mock_inventory.clear_pattern_cache.return_value = None + mock_inventory.get_host_vars.return_value = {} + mock_inventory.hosts.get.return_value = mock_host + + mock_var_mgr = MagicMock() + mock_var_mgr.set_host_variable.return_value = None + mock_var_mgr.set_host_facts.return_value = None + mock_var_mgr.get_vars.return_value = dict() + + strategy_base = StrategyBase(tqm=mock_tqm) + strategy_base._inventory = mock_inventory + strategy_base._variable_manager = mock_var_mgr + strategy_base._blocked_hosts = dict() + + def _has_dead_workers(): + return False + + strategy_base._tqm.has_dead_workers.side_effect = _has_dead_workers + results = strategy_base._wait_on_pending_results(iterator=mock_iterator) + self.assertEqual(len(results), 0) + + task_result = TaskResult(host=mock_host.name, task=mock_task._uuid, return_data=dict(changed=True)) + queue_items.append(task_result) + strategy_base._blocked_hosts['test01'] = True + strategy_base._pending_results = 1 + + mock_queued_task_cache = { + (mock_host.name, mock_task._uuid): { + 'task': mock_task, + 'host': mock_host, + 'task_vars': {}, + 'play_context': {}, + } + } + + strategy_base._queued_task_cache = deepcopy(mock_queued_task_cache) + results = strategy_base._wait_on_pending_results(iterator=mock_iterator) + self.assertEqual(len(results), 1) + self.assertEqual(results[0], task_result) + self.assertEqual(strategy_base._pending_results, 0) + self.assertNotIn('test01', strategy_base._blocked_hosts) + + task_result = TaskResult(host=mock_host.name, task=mock_task._uuid, return_data='{"failed":true}') + queue_items.append(task_result) + strategy_base._blocked_hosts['test01'] = True + strategy_base._pending_results = 1 + mock_iterator.is_failed.return_value = True + strategy_base._queued_task_cache = deepcopy(mock_queued_task_cache) + results = strategy_base._wait_on_pending_results(iterator=mock_iterator) + self.assertEqual(len(results), 1) + self.assertEqual(results[0], task_result) + self.assertEqual(strategy_base._pending_results, 0) + self.assertNotIn('test01', strategy_base._blocked_hosts) + # self.assertIn('test01', mock_tqm._failed_hosts) + # del mock_tqm._failed_hosts['test01'] + mock_iterator.is_failed.return_value = False + + task_result = TaskResult(host=mock_host.name, task=mock_task._uuid, return_data='{"unreachable": true}') + queue_items.append(task_result) + strategy_base._blocked_hosts['test01'] = True + strategy_base._pending_results = 1 + strategy_base._queued_task_cache = deepcopy(mock_queued_task_cache) + results = strategy_base._wait_on_pending_results(iterator=mock_iterator) + self.assertEqual(len(results), 1) + self.assertEqual(results[0], task_result) + self.assertEqual(strategy_base._pending_results, 0) + self.assertNotIn('test01', strategy_base._blocked_hosts) + self.assertIn('test01', mock_tqm._unreachable_hosts) + del mock_tqm._unreachable_hosts['test01'] + + task_result = TaskResult(host=mock_host.name, task=mock_task._uuid, return_data='{"skipped": true}') + queue_items.append(task_result) + strategy_base._blocked_hosts['test01'] = True + strategy_base._pending_results = 1 + strategy_base._queued_task_cache = deepcopy(mock_queued_task_cache) + results = strategy_base._wait_on_pending_results(iterator=mock_iterator) + self.assertEqual(len(results), 1) + self.assertEqual(results[0], task_result) + self.assertEqual(strategy_base._pending_results, 0) + self.assertNotIn('test01', strategy_base._blocked_hosts) + + queue_items.append(TaskResult(host=mock_host.name, task=mock_task._uuid, return_data=dict(add_host=dict(host_name='newhost01', new_groups=['foo'])))) + strategy_base._blocked_hosts['test01'] = True + strategy_base._pending_results = 1 + strategy_base._queued_task_cache = deepcopy(mock_queued_task_cache) + results = strategy_base._wait_on_pending_results(iterator=mock_iterator) + self.assertEqual(len(results), 1) + self.assertEqual(strategy_base._pending_results, 0) + self.assertNotIn('test01', strategy_base._blocked_hosts) + + queue_items.append(TaskResult(host=mock_host.name, task=mock_task._uuid, return_data=dict(add_group=dict(group_name='foo')))) + strategy_base._blocked_hosts['test01'] = True + strategy_base._pending_results = 1 + strategy_base._queued_task_cache = deepcopy(mock_queued_task_cache) + results = strategy_base._wait_on_pending_results(iterator=mock_iterator) + self.assertEqual(len(results), 1) + self.assertEqual(strategy_base._pending_results, 0) + self.assertNotIn('test01', strategy_base._blocked_hosts) + + queue_items.append(TaskResult(host=mock_host.name, task=mock_task._uuid, return_data=dict(changed=True, _ansible_notify=['test handler']))) + strategy_base._blocked_hosts['test01'] = True + strategy_base._pending_results = 1 + strategy_base._queued_task_cache = deepcopy(mock_queued_task_cache) + results = strategy_base._wait_on_pending_results(iterator=mock_iterator) + self.assertEqual(len(results), 1) + self.assertEqual(strategy_base._pending_results, 0) + self.assertNotIn('test01', strategy_base._blocked_hosts) + self.assertTrue(mock_handler_task.is_host_notified(mock_host)) + + # queue_items.append(('set_host_var', mock_host, mock_task, None, 'foo', 'bar')) + # results = strategy_base._process_pending_results(iterator=mock_iterator) + # self.assertEqual(len(results), 0) + # self.assertEqual(strategy_base._pending_results, 1) + + # queue_items.append(('set_host_facts', mock_host, mock_task, None, 'foo', dict())) + # results = strategy_base._process_pending_results(iterator=mock_iterator) + # self.assertEqual(len(results), 0) + # self.assertEqual(strategy_base._pending_results, 1) + + # queue_items.append(('bad')) + # self.assertRaises(AnsibleError, strategy_base._process_pending_results, iterator=mock_iterator) + strategy_base.cleanup() + + def test_strategy_base_load_included_file(self): + fake_loader = DictDataLoader({ + "test.yml": """ + - debug: msg='foo' + """, + "bad.yml": """ + """, + }) + + queue_items = [] + + def _queue_empty(*args, **kwargs): + return len(queue_items) == 0 + + def _queue_get(*args, **kwargs): + if len(queue_items) == 0: + raise Queue.Empty + else: + return queue_items.pop() + + def _queue_put(item, *args, **kwargs): + queue_items.append(item) + + mock_queue = MagicMock() + mock_queue.empty.side_effect = _queue_empty + mock_queue.get.side_effect = _queue_get + mock_queue.put.side_effect = _queue_put + + mock_tqm = MagicMock() + mock_tqm._final_q = mock_queue + + strategy_base = StrategyBase(tqm=mock_tqm) + strategy_base._loader = fake_loader + strategy_base.cleanup() + + mock_play = MagicMock() + + mock_block = MagicMock() + mock_block._play = mock_play + mock_block.vars = dict() + + mock_task = MagicMock() + mock_task._block = mock_block + mock_task._role = None + mock_task._parent = None + + mock_iterator = MagicMock() + mock_iterator.mark_host_failed.return_value = None + + mock_inc_file = MagicMock() + mock_inc_file._task = mock_task + + mock_inc_file._filename = "test.yml" + res = strategy_base._load_included_file(included_file=mock_inc_file, iterator=mock_iterator) + + mock_inc_file._filename = "bad.yml" + res = strategy_base._load_included_file(included_file=mock_inc_file, iterator=mock_iterator) + self.assertEqual(res, []) + + @patch.object(WorkerProcess, 'run') + def test_strategy_base_run_handlers(self, mock_worker): + def fake_run(*args): + return + mock_worker.side_effect = fake_run + mock_play_context = MagicMock() + + mock_handler_task = Handler() + mock_handler_task.action = 'foo' + mock_handler_task.cached_name = False + mock_handler_task.name = "test handler" + mock_handler_task.listen = [] + mock_handler_task._role = None + mock_handler_task._parent = None + mock_handler_task._uuid = 'xxxxxxxxxxxxxxxx' + + mock_handler = MagicMock() + mock_handler.block = [mock_handler_task] + mock_handler.flag_for_host.return_value = False + + mock_play = MagicMock() + mock_play.handlers = [mock_handler] + + mock_host = MagicMock(Host) + mock_host.name = "test01" + mock_host.has_hostkey = True + + mock_inventory = MagicMock() + mock_inventory.get_hosts.return_value = [mock_host] + mock_inventory.get.return_value = mock_host + mock_inventory.get_host.return_value = mock_host + + mock_var_mgr = MagicMock() + mock_var_mgr.get_vars.return_value = dict() + + mock_iterator = MagicMock() + mock_iterator._play = mock_play + + fake_loader = DictDataLoader() + + tqm = TaskQueueManager( + inventory=mock_inventory, + variable_manager=mock_var_mgr, + loader=fake_loader, + passwords=None, + forks=5, + ) + tqm._initialize_processes(3) + tqm.hostvars = dict() + + try: + strategy_base = StrategyBase(tqm=tqm) + + strategy_base._inventory = mock_inventory + + task_result = TaskResult(mock_host.name, mock_handler_task._uuid, dict(changed=False)) + strategy_base._queued_task_cache = dict() + strategy_base._queued_task_cache[(mock_host.name, mock_handler_task._uuid)] = { + 'task': mock_handler_task, + 'host': mock_host, + 'task_vars': {}, + 'play_context': mock_play_context + } + tqm._final_q.put(task_result) + + result = strategy_base.run_handlers(iterator=mock_iterator, play_context=mock_play_context) + finally: + strategy_base.cleanup() + tqm.cleanup() diff --git a/test/units/plugins/test_plugins.py b/test/units/plugins/test_plugins.py new file mode 100644 index 00000000..c9d80cda --- /dev/null +++ b/test/units/plugins/test_plugins.py @@ -0,0 +1,134 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os + +from units.compat import unittest +from units.compat.builtins import BUILTINS +from units.compat.mock import patch, MagicMock +from ansible.plugins.loader import PluginLoader, PluginPathContext + + +class TestErrors(unittest.TestCase): + + @patch.object(PluginLoader, '_get_paths') + def test_print_paths(self, mock_method): + mock_method.return_value = ['/path/one', '/path/two', '/path/three'] + pl = PluginLoader('foo', 'foo', '', 'test_plugins') + paths = pl.print_paths() + expected_paths = os.pathsep.join(['/path/one', '/path/two', '/path/three']) + self.assertEqual(paths, expected_paths) + + def test_plugins__get_package_paths_no_package(self): + pl = PluginLoader('test', '', 'test', 'test_plugin') + self.assertEqual(pl._get_package_paths(), []) + + def test_plugins__get_package_paths_with_package(self): + # the _get_package_paths() call uses __import__ to load a + # python library, and then uses the __file__ attribute of + # the result for that to get the library path, so we mock + # that here and patch the builtin to use our mocked result + foo = MagicMock() + bar = MagicMock() + bam = MagicMock() + bam.__file__ = '/path/to/my/foo/bar/bam/__init__.py' + bar.bam = bam + foo.return_value.bar = bar + pl = PluginLoader('test', 'foo.bar.bam', 'test', 'test_plugin') + with patch('{0}.__import__'.format(BUILTINS), foo): + self.assertEqual(pl._get_package_paths(), ['/path/to/my/foo/bar/bam']) + + def test_plugins__get_paths(self): + pl = PluginLoader('test', '', 'test', 'test_plugin') + pl._paths = [PluginPathContext('/path/one', False), + PluginPathContext('/path/two', True)] + self.assertEqual(pl._get_paths(), ['/path/one', '/path/two']) + + # NOT YET WORKING + # def fake_glob(path): + # if path == 'test/*': + # return ['test/foo', 'test/bar', 'test/bam'] + # elif path == 'test/*/*' + # m._paths = None + # mock_glob = MagicMock() + # mock_glob.return_value = [] + # with patch('glob.glob', mock_glob): + # pass + + def assertPluginLoaderConfigBecomes(self, arg, expected): + pl = PluginLoader('test', '', arg, 'test_plugin') + self.assertEqual(pl.config, expected) + + def test_plugin__init_config_list(self): + config = ['/one', '/two'] + self.assertPluginLoaderConfigBecomes(config, config) + + def test_plugin__init_config_str(self): + self.assertPluginLoaderConfigBecomes('test', ['test']) + + def test_plugin__init_config_none(self): + self.assertPluginLoaderConfigBecomes(None, []) + + def test__load_module_source_no_duplicate_names(self): + ''' + This test simulates importing 2 plugins with the same name, + and validating that the import is short circuited if a file with the same name + has already been imported + ''' + + fixture_path = os.path.join(os.path.dirname(__file__), 'loader_fixtures') + + pl = PluginLoader('test', '', 'test', 'test_plugin') + one = pl._load_module_source('import_fixture', os.path.join(fixture_path, 'import_fixture.py')) + # This line wouldn't even succeed if we didn't short circuit on finding a duplicate name + two = pl._load_module_source('import_fixture', '/path/to/import_fixture.py') + + self.assertEqual(one, two) + + @patch('ansible.plugins.loader.glob') + @patch.object(PluginLoader, '_get_paths') + def test_all_no_duplicate_names(self, gp_mock, glob_mock): + ''' + This test goes along with ``test__load_module_source_no_duplicate_names`` + and ensures that we ignore duplicate imports on multiple paths + ''' + + fixture_path = os.path.join(os.path.dirname(__file__), 'loader_fixtures') + + gp_mock.return_value = [ + fixture_path, + '/path/to' + ] + + glob_mock.glob.side_effect = [ + [os.path.join(fixture_path, 'import_fixture.py')], + ['/path/to/import_fixture.py'] + ] + + pl = PluginLoader('test', '', 'test', 'test_plugin') + # Aside from needing ``list()`` so we can do a len, ``PluginLoader.all`` returns a generator + # so ``list()`` actually causes ``PluginLoader.all`` to run. + plugins = list(pl.all()) + self.assertEqual(len(plugins), 1) + + self.assertIn(os.path.join(fixture_path, 'import_fixture.py'), pl._module_cache) + self.assertNotIn('/path/to/import_fixture.py', pl._module_cache) |