summaryrefslogtreecommitdiffstats
path: root/test/units/plugins
diff options
context:
space:
mode:
Diffstat (limited to 'test/units/plugins')
-rw-r--r--test/units/plugins/__init__.py0
-rw-r--r--test/units/plugins/action/__init__.py0
-rw-r--r--test/units/plugins/action/test_action.py912
-rw-r--r--test/units/plugins/action/test_gather_facts.py98
-rw-r--r--test/units/plugins/action/test_pause.py89
-rw-r--r--test/units/plugins/action/test_raw.py105
-rw-r--r--test/units/plugins/become/__init__.py0
-rw-r--r--test/units/plugins/become/conftest.py37
-rw-r--r--test/units/plugins/become/test_su.py30
-rw-r--r--test/units/plugins/become/test_sudo.py67
-rw-r--r--test/units/plugins/cache/__init__.py0
-rw-r--r--test/units/plugins/cache/test_cache.py199
-rw-r--r--test/units/plugins/callback/__init__.py0
-rw-r--r--test/units/plugins/callback/test_callback.py416
-rw-r--r--test/units/plugins/connection/__init__.py0
-rw-r--r--test/units/plugins/connection/test_connection.py163
-rw-r--r--test/units/plugins/connection/test_local.py40
-rw-r--r--test/units/plugins/connection/test_paramiko.py56
-rw-r--r--test/units/plugins/connection/test_psrp.py233
-rw-r--r--test/units/plugins/connection/test_ssh.py696
-rw-r--r--test/units/plugins/connection/test_winrm.py443
-rw-r--r--test/units/plugins/filter/__init__.py0
-rw-r--r--test/units/plugins/filter/test_core.py43
-rw-r--r--test/units/plugins/filter/test_mathstuff.py162
-rw-r--r--test/units/plugins/inventory/__init__.py0
-rw-r--r--test/units/plugins/inventory/test_constructed.py337
-rw-r--r--test/units/plugins/inventory/test_inventory.py208
-rw-r--r--test/units/plugins/inventory/test_script.py105
-rw-r--r--test/units/plugins/loader_fixtures/__init__.py0
-rw-r--r--test/units/plugins/loader_fixtures/import_fixture.py9
-rw-r--r--test/units/plugins/lookup/__init__.py0
-rw-r--r--test/units/plugins/lookup/test_env.py35
-rw-r--r--test/units/plugins/lookup/test_ini.py64
-rw-r--r--test/units/plugins/lookup/test_password.py577
-rw-r--r--test/units/plugins/lookup/test_url.py26
-rw-r--r--test/units/plugins/shell/__init__.py0
-rw-r--r--test/units/plugins/shell/test_cmd.py19
-rw-r--r--test/units/plugins/shell/test_powershell.py83
-rw-r--r--test/units/plugins/strategy/__init__.py0
-rw-r--r--test/units/plugins/strategy/test_linear.py320
-rw-r--r--test/units/plugins/strategy/test_strategy.py492
-rw-r--r--test/units/plugins/test_plugins.py133
42 files changed, 6197 insertions, 0 deletions
diff --git a/test/units/plugins/__init__.py b/test/units/plugins/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/units/plugins/__init__.py
diff --git a/test/units/plugins/action/__init__.py b/test/units/plugins/action/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/units/plugins/action/__init__.py
diff --git a/test/units/plugins/action/test_action.py b/test/units/plugins/action/test_action.py
new file mode 100644
index 0000000..f2bbe19
--- /dev/null
+++ b/test/units/plugins/action/test_action.py
@@ -0,0 +1,912 @@
+# -*- coding: utf-8 -*-
+# (c) 2015, Florian Apolloner <florian@apolloner.eu>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import os
+import re
+
+from ansible import constants as C
+from units.compat import unittest
+from unittest.mock import patch, MagicMock, mock_open
+
+from ansible.errors import AnsibleError, AnsibleAuthenticationFailure
+from ansible.module_utils.six import text_type
+from ansible.module_utils.six.moves import shlex_quote, builtins
+from ansible.module_utils._text import to_bytes
+from ansible.playbook.play_context import PlayContext
+from ansible.plugins.action import ActionBase
+from ansible.template import Templar
+from ansible.vars.clean import clean_facts
+
+from units.mock.loader import DictDataLoader
+
+
+python_module_replacers = br"""
+#!/usr/bin/python
+
+#ANSIBLE_VERSION = "<<ANSIBLE_VERSION>>"
+#MODULE_COMPLEX_ARGS = "<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>"
+#SELINUX_SPECIAL_FS="<<SELINUX_SPECIAL_FILESYSTEMS>>"
+
+test = u'Toshio \u304f\u3089\u3068\u307f'
+from ansible.module_utils.basic import *
+"""
+
+powershell_module_replacers = b"""
+WINDOWS_ARGS = "<<INCLUDE_ANSIBLE_MODULE_JSON_ARGS>>"
+# POWERSHELL_COMMON
+"""
+
+
+def _action_base():
+ fake_loader = DictDataLoader({
+ })
+ mock_module_loader = MagicMock()
+ mock_shared_loader_obj = MagicMock()
+ mock_shared_loader_obj.module_loader = mock_module_loader
+ mock_connection_loader = MagicMock()
+
+ mock_shared_loader_obj.connection_loader = mock_connection_loader
+ mock_connection = MagicMock()
+
+ play_context = MagicMock()
+
+ action_base = DerivedActionBase(task=None,
+ connection=mock_connection,
+ play_context=play_context,
+ loader=fake_loader,
+ templar=None,
+ shared_loader_obj=mock_shared_loader_obj)
+ return action_base
+
+
+class DerivedActionBase(ActionBase):
+ TRANSFERS_FILES = False
+
+ def run(self, tmp=None, task_vars=None):
+ # We're not testing the plugin run() method, just the helper
+ # methods ActionBase defines
+ return super(DerivedActionBase, self).run(tmp=tmp, task_vars=task_vars)
+
+
+class TestActionBase(unittest.TestCase):
+
+ def test_action_base_run(self):
+ mock_task = MagicMock()
+ mock_task.action = "foo"
+ mock_task.args = dict(a=1, b=2, c=3)
+
+ mock_connection = MagicMock()
+
+ play_context = PlayContext()
+
+ mock_task.async_val = None
+ action_base = DerivedActionBase(mock_task, mock_connection, play_context, None, None, None)
+ results = action_base.run()
+ self.assertEqual(results, dict())
+
+ mock_task.async_val = 0
+ action_base = DerivedActionBase(mock_task, mock_connection, play_context, None, None, None)
+ results = action_base.run()
+ self.assertEqual(results, {})
+
+ def test_action_base__configure_module(self):
+ fake_loader = DictDataLoader({
+ })
+
+ # create our fake task
+ mock_task = MagicMock()
+ mock_task.action = "copy"
+ mock_task.async_val = 0
+ mock_task.delegate_to = None
+
+ # create a mock connection, so we don't actually try and connect to things
+ mock_connection = MagicMock()
+
+ # create a mock shared loader object
+ def mock_find_plugin_with_context(name, options, collection_list=None):
+ mockctx = MagicMock()
+ if name == 'badmodule':
+ mockctx.resolved = False
+ mockctx.plugin_resolved_path = None
+ elif '.ps1' in options:
+ mockctx.resolved = True
+ mockctx.plugin_resolved_path = '/fake/path/to/%s.ps1' % name
+ else:
+ mockctx.resolved = True
+ mockctx.plugin_resolved_path = '/fake/path/to/%s' % name
+ return mockctx
+
+ mock_module_loader = MagicMock()
+ mock_module_loader.find_plugin_with_context.side_effect = mock_find_plugin_with_context
+ mock_shared_obj_loader = MagicMock()
+ mock_shared_obj_loader.module_loader = mock_module_loader
+
+ # we're using a real play context here
+ play_context = PlayContext()
+
+ # our test class
+ action_base = DerivedActionBase(
+ task=mock_task,
+ connection=mock_connection,
+ play_context=play_context,
+ loader=fake_loader,
+ templar=Templar(loader=fake_loader),
+ shared_loader_obj=mock_shared_obj_loader,
+ )
+
+ # test python module formatting
+ with patch.object(builtins, 'open', mock_open(read_data=to_bytes(python_module_replacers.strip(), encoding='utf-8'))):
+ with patch.object(os, 'rename'):
+ mock_task.args = dict(a=1, foo='fö〩')
+ mock_connection.module_implementation_preferences = ('',)
+ (style, shebang, data, path) = action_base._configure_module(mock_task.action, mock_task.args,
+ task_vars=dict(ansible_python_interpreter='/usr/bin/python',
+ ansible_playbook_python='/usr/bin/python'))
+ self.assertEqual(style, "new")
+ self.assertEqual(shebang, u"#!/usr/bin/python")
+
+ # test module not found
+ self.assertRaises(AnsibleError, action_base._configure_module, 'badmodule', mock_task.args, {})
+
+ # test powershell module formatting
+ with patch.object(builtins, 'open', mock_open(read_data=to_bytes(powershell_module_replacers.strip(), encoding='utf-8'))):
+ mock_task.action = 'win_copy'
+ mock_task.args = dict(b=2)
+ mock_connection.module_implementation_preferences = ('.ps1',)
+ (style, shebang, data, path) = action_base._configure_module('stat', mock_task.args, {})
+ self.assertEqual(style, "new")
+ self.assertEqual(shebang, u'#!powershell')
+
+ # test module not found
+ self.assertRaises(AnsibleError, action_base._configure_module, 'badmodule', mock_task.args, {})
+
+ def test_action_base__compute_environment_string(self):
+ fake_loader = DictDataLoader({
+ })
+
+ # create our fake task
+ mock_task = MagicMock()
+ mock_task.action = "copy"
+ mock_task.args = dict(a=1)
+
+ # create a mock connection, so we don't actually try and connect to things
+ def env_prefix(**args):
+ return ' '.join(['%s=%s' % (k, shlex_quote(text_type(v))) for k, v in args.items()])
+ mock_connection = MagicMock()
+ mock_connection._shell.env_prefix.side_effect = env_prefix
+
+ # we're using a real play context here
+ play_context = PlayContext()
+
+ # and we're using a real templar here too
+ templar = Templar(loader=fake_loader)
+
+ # our test class
+ action_base = DerivedActionBase(
+ task=mock_task,
+ connection=mock_connection,
+ play_context=play_context,
+ loader=fake_loader,
+ templar=templar,
+ shared_loader_obj=None,
+ )
+
+ # test standard environment setup
+ mock_task.environment = [dict(FOO='foo'), None]
+ env_string = action_base._compute_environment_string()
+ self.assertEqual(env_string, "FOO=foo")
+
+ # test where environment is not a list
+ mock_task.environment = dict(FOO='foo')
+ env_string = action_base._compute_environment_string()
+ self.assertEqual(env_string, "FOO=foo")
+
+ # test environment with a variable in it
+ templar.available_variables = dict(the_var='bar')
+ mock_task.environment = [dict(FOO='{{the_var}}')]
+ env_string = action_base._compute_environment_string()
+ self.assertEqual(env_string, "FOO=bar")
+
+ # test with a bad environment set
+ mock_task.environment = dict(FOO='foo')
+ mock_task.environment = ['hi there']
+ self.assertRaises(AnsibleError, action_base._compute_environment_string)
+
+ def test_action_base__early_needs_tmp_path(self):
+ # create our fake task
+ mock_task = MagicMock()
+
+ # create a mock connection, so we don't actually try and connect to things
+ mock_connection = MagicMock()
+
+ # we're using a real play context here
+ play_context = PlayContext()
+
+ # our test class
+ action_base = DerivedActionBase(
+ task=mock_task,
+ connection=mock_connection,
+ play_context=play_context,
+ loader=None,
+ templar=None,
+ shared_loader_obj=None,
+ )
+
+ self.assertFalse(action_base._early_needs_tmp_path())
+
+ action_base.TRANSFERS_FILES = True
+ self.assertTrue(action_base._early_needs_tmp_path())
+
+ def test_action_base__make_tmp_path(self):
+ # create our fake task
+ mock_task = MagicMock()
+
+ def get_shell_opt(opt):
+
+ ret = None
+ if opt == 'admin_users':
+ ret = ['root', 'toor', 'Administrator']
+ elif opt == 'remote_tmp':
+ ret = '~/.ansible/tmp'
+
+ return ret
+
+ # create a mock connection, so we don't actually try and connect to things
+ mock_connection = MagicMock()
+ mock_connection.transport = 'ssh'
+ mock_connection._shell.mkdtemp.return_value = 'mkdir command'
+ mock_connection._shell.join_path.side_effect = os.path.join
+ mock_connection._shell.get_option = get_shell_opt
+ mock_connection._shell.HOMES_RE = re.compile(r'(\'|\")?(~|\$HOME)(.*)')
+
+ # we're using a real play context here
+ play_context = PlayContext()
+ play_context.become = True
+ play_context.become_user = 'foo'
+
+ mock_task.become = True
+ mock_task.become_user = True
+
+ # our test class
+ action_base = DerivedActionBase(
+ task=mock_task,
+ connection=mock_connection,
+ play_context=play_context,
+ loader=None,
+ templar=None,
+ shared_loader_obj=None,
+ )
+
+ action_base._low_level_execute_command = MagicMock()
+ action_base._low_level_execute_command.return_value = dict(rc=0, stdout='/some/path')
+ self.assertEqual(action_base._make_tmp_path('root'), '/some/path/')
+
+ # empty path fails
+ action_base._low_level_execute_command.return_value = dict(rc=0, stdout='')
+ self.assertRaises(AnsibleError, action_base._make_tmp_path, 'root')
+
+ # authentication failure
+ action_base._low_level_execute_command.return_value = dict(rc=5, stdout='')
+ self.assertRaises(AnsibleError, action_base._make_tmp_path, 'root')
+
+ # ssh error
+ action_base._low_level_execute_command.return_value = dict(rc=255, stdout='', stderr='')
+ self.assertRaises(AnsibleError, action_base._make_tmp_path, 'root')
+ self.assertRaises(AnsibleError, action_base._make_tmp_path, 'root')
+
+ # general error
+ action_base._low_level_execute_command.return_value = dict(rc=1, stdout='some stuff here', stderr='')
+ self.assertRaises(AnsibleError, action_base._make_tmp_path, 'root')
+ action_base._low_level_execute_command.return_value = dict(rc=1, stdout='some stuff here', stderr='No space left on device')
+ self.assertRaises(AnsibleError, action_base._make_tmp_path, 'root')
+
+ def test_action_base__fixup_perms2(self):
+ mock_task = MagicMock()
+ mock_connection = MagicMock()
+ play_context = PlayContext()
+ action_base = DerivedActionBase(
+ task=mock_task,
+ connection=mock_connection,
+ play_context=play_context,
+ loader=None,
+ templar=None,
+ shared_loader_obj=None,
+ )
+ action_base._low_level_execute_command = MagicMock()
+ remote_paths = ['/tmp/foo/bar.txt', '/tmp/baz.txt']
+ remote_user = 'remoteuser1'
+
+ # Used for skipping down to common group dir.
+ CHMOD_ACL_FLAGS = ('+a', 'A+user:remoteuser2:r:allow')
+
+ def runWithNoExpectation(execute=False):
+ return action_base._fixup_perms2(
+ remote_paths,
+ remote_user=remote_user,
+ execute=execute)
+
+ def assertSuccess(execute=False):
+ self.assertEqual(runWithNoExpectation(execute), remote_paths)
+
+ def assertThrowRegex(regex, execute=False):
+ self.assertRaisesRegex(
+ AnsibleError,
+ regex,
+ action_base._fixup_perms2,
+ remote_paths,
+ remote_user=remote_user,
+ execute=execute)
+
+ def get_shell_option_for_arg(args_kv, default):
+ '''A helper for get_shell_option. Returns a function that, if
+ called with ``option`` that exists in args_kv, will return the
+ value, else will return ``default`` for every other given arg'''
+ def _helper(option, *args, **kwargs):
+ return args_kv.get(option, default)
+ return _helper
+
+ action_base.get_become_option = MagicMock()
+ action_base.get_become_option.return_value = 'remoteuser2'
+
+ # Step 1: On windows, we just return remote_paths
+ action_base._connection._shell._IS_WINDOWS = True
+ assertSuccess(execute=False)
+ assertSuccess(execute=True)
+
+ # But if we're not on windows....we have more work to do.
+ action_base._connection._shell._IS_WINDOWS = False
+
+ # Step 2: We're /not/ becoming an unprivileged user
+ action_base._remote_chmod = MagicMock()
+ action_base._is_become_unprivileged = MagicMock()
+ action_base._is_become_unprivileged.return_value = False
+ # Two subcases:
+ # - _remote_chmod rc is 0
+ # - _remote-chmod rc is not 0, something failed
+ action_base._remote_chmod.return_value = {
+ 'rc': 0,
+ 'stdout': 'some stuff here',
+ 'stderr': '',
+ }
+ assertSuccess(execute=True)
+
+ # When execute=False, we just get the list back. But add it here for
+ # completion. chmod is never called.
+ assertSuccess()
+
+ action_base._remote_chmod.return_value = {
+ 'rc': 1,
+ 'stdout': 'some stuff here',
+ 'stderr': 'and here',
+ }
+ assertThrowRegex(
+ 'Failed to set execute bit on remote files',
+ execute=True)
+
+ # Step 3: we are becoming unprivileged
+ action_base._is_become_unprivileged.return_value = True
+
+ # Step 3a: setfacl
+ action_base._remote_set_user_facl = MagicMock()
+ action_base._remote_set_user_facl.return_value = {
+ 'rc': 0,
+ 'stdout': '',
+ 'stderr': '',
+ }
+ assertSuccess()
+
+ # Step 3b: chmod +x if we need to
+ # To get here, setfacl failed, so mock it as such.
+ action_base._remote_set_user_facl.return_value = {
+ 'rc': 1,
+ 'stdout': '',
+ 'stderr': '',
+ }
+ action_base._remote_chmod.return_value = {
+ 'rc': 1,
+ 'stdout': 'some stuff here',
+ 'stderr': '',
+ }
+ assertThrowRegex(
+ 'Failed to set file mode or acl on remote temporary files',
+ execute=True)
+ action_base._remote_chmod.return_value = {
+ 'rc': 0,
+ 'stdout': 'some stuff here',
+ 'stderr': '',
+ }
+ assertSuccess(execute=True)
+
+ # Step 3c: chown
+ action_base._remote_chown = MagicMock()
+ action_base._remote_chown.return_value = {
+ 'rc': 0,
+ 'stdout': '',
+ 'stderr': '',
+ }
+ assertSuccess()
+ action_base._remote_chown.return_value = {
+ 'rc': 1,
+ 'stdout': '',
+ 'stderr': '',
+ }
+ remote_user = 'root'
+ action_base._get_admin_users = MagicMock()
+ action_base._get_admin_users.return_value = ['root']
+ assertThrowRegex('user would be unable to read the file.')
+ remote_user = 'remoteuser1'
+
+ # Step 3d: chmod +a on osx
+ assertSuccess()
+ action_base._remote_chmod.assert_called_with(
+ ['remoteuser2 allow read'] + remote_paths,
+ '+a')
+
+ # This case can cause Solaris chmod to return 5 which the ssh plugin
+ # treats as failure. To prevent a regression and ensure we still try the
+ # rest of the cases below, we mock the thrown exception here.
+ # This function ensures that only the macOS case (+a) throws this.
+ def raise_if_plus_a(definitely_not_underscore, mode):
+ if mode == '+a':
+ raise AnsibleAuthenticationFailure()
+ return {'rc': 0, 'stdout': '', 'stderr': ''}
+
+ action_base._remote_chmod.side_effect = raise_if_plus_a
+ assertSuccess()
+
+ # Step 3e: chmod A+ on Solaris
+ # We threw AnsibleAuthenticationFailure above, try Solaris fallback.
+ # Based on our lambda above, it should be successful.
+ action_base._remote_chmod.assert_called_with(
+ remote_paths,
+ 'A+user:remoteuser2:r:allow')
+ assertSuccess()
+
+ # Step 3f: Common group
+ def rc_1_if_chmod_acl(definitely_not_underscore, mode):
+ rc = 0
+ if mode in CHMOD_ACL_FLAGS:
+ rc = 1
+ return {'rc': rc, 'stdout': '', 'stderr': ''}
+
+ action_base._remote_chmod = MagicMock()
+ action_base._remote_chmod.side_effect = rc_1_if_chmod_acl
+
+ get_shell_option = action_base.get_shell_option
+ action_base.get_shell_option = MagicMock()
+ action_base.get_shell_option.side_effect = get_shell_option_for_arg(
+ {
+ 'common_remote_group': 'commongroup',
+ },
+ None)
+ action_base._remote_chgrp = MagicMock()
+ action_base._remote_chgrp.return_value = {
+ 'rc': 0,
+ 'stdout': '',
+ 'stderr': '',
+ }
+ # TODO: Add test to assert warning is shown if
+ # world_readable_temp is set in this case.
+ assertSuccess()
+ action_base._remote_chgrp.assert_called_once_with(
+ remote_paths,
+ 'commongroup')
+
+ # Step 4: world-readable tmpdir
+ action_base.get_shell_option.side_effect = get_shell_option_for_arg(
+ {
+ 'world_readable_temp': True,
+ 'common_remote_group': None,
+ },
+ None)
+ action_base._remote_chmod.return_value = {
+ 'rc': 0,
+ 'stdout': 'some stuff here',
+ 'stderr': '',
+ }
+ assertSuccess()
+ action_base._remote_chmod = MagicMock()
+ action_base._remote_chmod.return_value = {
+ 'rc': 1,
+ 'stdout': 'some stuff here',
+ 'stderr': '',
+ }
+ assertThrowRegex('Failed to set file mode on remote files')
+
+ # Otherwise if we make it here in this state, we hit the catch-all
+ action_base.get_shell_option.side_effect = get_shell_option_for_arg(
+ {},
+ None)
+ assertThrowRegex('on the temporary files Ansible needs to create')
+
+ def test_action_base__remove_tmp_path(self):
+ # create our fake task
+ mock_task = MagicMock()
+
+ # create a mock connection, so we don't actually try and connect to things
+ mock_connection = MagicMock()
+ mock_connection._shell.remove.return_value = 'rm some stuff'
+
+ # we're using a real play context here
+ play_context = PlayContext()
+
+ # our test class
+ action_base = DerivedActionBase(
+ task=mock_task,
+ connection=mock_connection,
+ play_context=play_context,
+ loader=None,
+ templar=None,
+ shared_loader_obj=None,
+ )
+
+ action_base._low_level_execute_command = MagicMock()
+ # these don't really return anything or raise errors, so
+ # we're pretty much calling these for coverage right now
+ action_base._remove_tmp_path('/bad/path/dont/remove')
+ action_base._remove_tmp_path('/good/path/to/ansible-tmp-thing')
+
+ @patch('os.unlink')
+ @patch('os.fdopen')
+ @patch('tempfile.mkstemp')
+ def test_action_base__transfer_data(self, mock_mkstemp, mock_fdopen, mock_unlink):
+ # create our fake task
+ mock_task = MagicMock()
+
+ # create a mock connection, so we don't actually try and connect to things
+ mock_connection = MagicMock()
+ mock_connection.put_file.return_value = None
+
+ # we're using a real play context here
+ play_context = PlayContext()
+
+ # our test class
+ action_base = DerivedActionBase(
+ task=mock_task,
+ connection=mock_connection,
+ play_context=play_context,
+ loader=None,
+ templar=None,
+ shared_loader_obj=None,
+ )
+
+ mock_afd = MagicMock()
+ mock_afile = MagicMock()
+ mock_mkstemp.return_value = (mock_afd, mock_afile)
+
+ mock_unlink.return_value = None
+
+ mock_afo = MagicMock()
+ mock_afo.write.return_value = None
+ mock_afo.flush.return_value = None
+ mock_afo.close.return_value = None
+ mock_fdopen.return_value = mock_afo
+
+ self.assertEqual(action_base._transfer_data('/path/to/remote/file', 'some data'), '/path/to/remote/file')
+ self.assertEqual(action_base._transfer_data('/path/to/remote/file', 'some mixed data: fö〩'), '/path/to/remote/file')
+ self.assertEqual(action_base._transfer_data('/path/to/remote/file', dict(some_key='some value')), '/path/to/remote/file')
+ self.assertEqual(action_base._transfer_data('/path/to/remote/file', dict(some_key='fö〩')), '/path/to/remote/file')
+
+ mock_afo.write.side_effect = Exception()
+ self.assertRaises(AnsibleError, action_base._transfer_data, '/path/to/remote/file', '')
+
+ def test_action_base__execute_remote_stat(self):
+ # create our fake task
+ mock_task = MagicMock()
+
+ # create a mock connection, so we don't actually try and connect to things
+ mock_connection = MagicMock()
+
+ # we're using a real play context here
+ play_context = PlayContext()
+
+ # our test class
+ action_base = DerivedActionBase(
+ task=mock_task,
+ connection=mock_connection,
+ play_context=play_context,
+ loader=None,
+ templar=None,
+ shared_loader_obj=None,
+ )
+
+ action_base._execute_module = MagicMock()
+
+ # test normal case
+ action_base._execute_module.return_value = dict(stat=dict(checksum='1111111111111111111111111111111111', exists=True))
+ res = action_base._execute_remote_stat(path='/path/to/file', all_vars=dict(), follow=False)
+ self.assertEqual(res['checksum'], '1111111111111111111111111111111111')
+
+ # test does not exist
+ action_base._execute_module.return_value = dict(stat=dict(exists=False))
+ res = action_base._execute_remote_stat(path='/path/to/file', all_vars=dict(), follow=False)
+ self.assertFalse(res['exists'])
+ self.assertEqual(res['checksum'], '1')
+
+ # test no checksum in result from _execute_module
+ action_base._execute_module.return_value = dict(stat=dict(exists=True))
+ res = action_base._execute_remote_stat(path='/path/to/file', all_vars=dict(), follow=False)
+ self.assertTrue(res['exists'])
+ self.assertEqual(res['checksum'], '')
+
+ # test stat call failed
+ action_base._execute_module.return_value = dict(failed=True, msg="because I said so")
+ self.assertRaises(AnsibleError, action_base._execute_remote_stat, path='/path/to/file', all_vars=dict(), follow=False)
+
+ def test_action_base__execute_module(self):
+ # create our fake task
+ mock_task = MagicMock()
+ mock_task.action = 'copy'
+ mock_task.args = dict(a=1, b=2, c=3)
+ mock_task.diff = False
+ mock_task.check_mode = False
+ mock_task.no_log = False
+
+ # create a mock connection, so we don't actually try and connect to things
+ def build_module_command(env_string, shebang, cmd, arg_path=None):
+ to_run = [env_string, cmd]
+ if arg_path:
+ to_run.append(arg_path)
+ return " ".join(to_run)
+
+ def get_option(option):
+ return {'admin_users': ['root', 'toor']}.get(option)
+
+ mock_connection = MagicMock()
+ mock_connection.build_module_command.side_effect = build_module_command
+ mock_connection.socket_path = None
+ mock_connection._shell.get_remote_filename.return_value = 'copy.py'
+ mock_connection._shell.join_path.side_effect = os.path.join
+ mock_connection._shell.tmpdir = '/var/tmp/mytempdir'
+ mock_connection._shell.get_option = get_option
+
+ # we're using a real play context here
+ play_context = PlayContext()
+
+ # our test class
+ action_base = DerivedActionBase(
+ task=mock_task,
+ connection=mock_connection,
+ play_context=play_context,
+ loader=None,
+ templar=None,
+ shared_loader_obj=None,
+ )
+
+ # fake a lot of methods as we test those elsewhere
+ action_base._configure_module = MagicMock()
+ action_base._supports_check_mode = MagicMock()
+ action_base._is_pipelining_enabled = MagicMock()
+ action_base._make_tmp_path = MagicMock()
+ action_base._transfer_data = MagicMock()
+ action_base._compute_environment_string = MagicMock()
+ action_base._low_level_execute_command = MagicMock()
+ action_base._fixup_perms2 = MagicMock()
+
+ action_base._configure_module.return_value = ('new', '#!/usr/bin/python', 'this is the module data', 'path')
+ action_base._is_pipelining_enabled.return_value = False
+ action_base._compute_environment_string.return_value = ''
+ action_base._connection.has_pipelining = False
+ action_base._make_tmp_path.return_value = '/the/tmp/path'
+ action_base._low_level_execute_command.return_value = dict(stdout='{"rc": 0, "stdout": "ok"}')
+ self.assertEqual(action_base._execute_module(module_name=None, module_args=None), dict(_ansible_parsed=True, rc=0, stdout="ok", stdout_lines=['ok']))
+ self.assertEqual(
+ action_base._execute_module(
+ module_name='foo',
+ module_args=dict(z=9, y=8, x=7),
+ task_vars=dict(a=1)
+ ),
+ dict(
+ _ansible_parsed=True,
+ rc=0,
+ stdout="ok",
+ stdout_lines=['ok'],
+ )
+ )
+
+ # test with needing/removing a remote tmp path
+ action_base._configure_module.return_value = ('old', '#!/usr/bin/python', 'this is the module data', 'path')
+ action_base._is_pipelining_enabled.return_value = False
+ action_base._make_tmp_path.return_value = '/the/tmp/path'
+ self.assertEqual(action_base._execute_module(), dict(_ansible_parsed=True, rc=0, stdout="ok", stdout_lines=['ok']))
+
+ action_base._configure_module.return_value = ('non_native_want_json', '#!/usr/bin/python', 'this is the module data', 'path')
+ self.assertEqual(action_base._execute_module(), dict(_ansible_parsed=True, rc=0, stdout="ok", stdout_lines=['ok']))
+
+ play_context.become = True
+ play_context.become_user = 'foo'
+ mock_task.become = True
+ mock_task.become_user = True
+ self.assertEqual(action_base._execute_module(), dict(_ansible_parsed=True, rc=0, stdout="ok", stdout_lines=['ok']))
+
+ # test an invalid shebang return
+ action_base._configure_module.return_value = ('new', '', 'this is the module data', 'path')
+ action_base._is_pipelining_enabled.return_value = False
+ action_base._make_tmp_path.return_value = '/the/tmp/path'
+ self.assertRaises(AnsibleError, action_base._execute_module)
+
+ # test with check mode enabled, once with support for check
+ # mode and once with support disabled to raise an error
+ play_context.check_mode = True
+ mock_task.check_mode = True
+ action_base._configure_module.return_value = ('new', '#!/usr/bin/python', 'this is the module data', 'path')
+ self.assertEqual(action_base._execute_module(), dict(_ansible_parsed=True, rc=0, stdout="ok", stdout_lines=['ok']))
+ action_base._supports_check_mode = False
+ self.assertRaises(AnsibleError, action_base._execute_module)
+
+ def test_action_base_sudo_only_if_user_differs(self):
+ fake_loader = MagicMock()
+ fake_loader.get_basedir.return_value = os.getcwd()
+ play_context = PlayContext()
+
+ action_base = DerivedActionBase(None, None, play_context, fake_loader, None, None)
+ action_base.get_become_option = MagicMock(return_value='root')
+ action_base._get_remote_user = MagicMock(return_value='root')
+
+ action_base._connection = MagicMock(exec_command=MagicMock(return_value=(0, '', '')))
+
+ action_base._connection._shell = shell = MagicMock(append_command=MagicMock(return_value=('JOINED CMD')))
+
+ action_base._connection.become = become = MagicMock()
+ become.build_become_command.return_value = 'foo'
+
+ action_base._low_level_execute_command('ECHO', sudoable=True)
+ become.build_become_command.assert_not_called()
+
+ action_base._get_remote_user.return_value = 'apo'
+ action_base._low_level_execute_command('ECHO', sudoable=True, executable='/bin/csh')
+ become.build_become_command.assert_called_once_with("ECHO", shell)
+
+ become.build_become_command.reset_mock()
+
+ with patch.object(C, 'BECOME_ALLOW_SAME_USER', new=True):
+ action_base._get_remote_user.return_value = 'root'
+ action_base._low_level_execute_command('ECHO SAME', sudoable=True)
+ become.build_become_command.assert_called_once_with("ECHO SAME", shell)
+
+ def test__remote_expand_user_relative_pathing(self):
+ action_base = _action_base()
+ action_base._play_context.remote_addr = 'bar'
+ action_base._connection.get_option.return_value = 'bar'
+ action_base._low_level_execute_command = MagicMock(return_value={'stdout': b'../home/user'})
+ action_base._connection._shell.join_path.return_value = '../home/user/foo'
+ with self.assertRaises(AnsibleError) as cm:
+ action_base._remote_expand_user('~/foo')
+ self.assertEqual(
+ cm.exception.message,
+ "'bar' returned an invalid relative home directory path containing '..'"
+ )
+
+
+class TestActionBaseCleanReturnedData(unittest.TestCase):
+ def test(self):
+
+ fake_loader = DictDataLoader({
+ })
+ mock_module_loader = MagicMock()
+ mock_shared_loader_obj = MagicMock()
+ mock_shared_loader_obj.module_loader = mock_module_loader
+ connection_loader_paths = ['/tmp/asdfadf', '/usr/lib64/whatever',
+ 'dfadfasf',
+ 'foo.py',
+ '.*',
+ # FIXME: a path with parans breaks the regex
+ # '(.*)',
+ '/path/to/ansible/lib/ansible/plugins/connection/custom_connection.py',
+ '/path/to/ansible/lib/ansible/plugins/connection/ssh.py']
+
+ def fake_all(path_only=None):
+ for path in connection_loader_paths:
+ yield path
+
+ mock_connection_loader = MagicMock()
+ mock_connection_loader.all = fake_all
+
+ mock_shared_loader_obj.connection_loader = mock_connection_loader
+ mock_connection = MagicMock()
+ # mock_connection._shell.env_prefix.side_effect = env_prefix
+
+ # action_base = DerivedActionBase(mock_task, mock_connection, play_context, None, None, None)
+ action_base = DerivedActionBase(task=None,
+ connection=mock_connection,
+ play_context=None,
+ loader=fake_loader,
+ templar=None,
+ shared_loader_obj=mock_shared_loader_obj)
+ data = {'ansible_playbook_python': '/usr/bin/python',
+ # 'ansible_rsync_path': '/usr/bin/rsync',
+ 'ansible_python_interpreter': '/usr/bin/python',
+ 'ansible_ssh_some_var': 'whatever',
+ 'ansible_ssh_host_key_somehost': 'some key here',
+ 'some_other_var': 'foo bar'}
+ data = clean_facts(data)
+ self.assertNotIn('ansible_playbook_python', data)
+ self.assertNotIn('ansible_python_interpreter', data)
+ self.assertIn('ansible_ssh_host_key_somehost', data)
+ self.assertIn('some_other_var', data)
+
+
+class TestActionBaseParseReturnedData(unittest.TestCase):
+
+ def test_fail_no_json(self):
+ action_base = _action_base()
+ rc = 0
+ stdout = 'foo\nbar\n'
+ err = 'oopsy'
+ returned_data = {'rc': rc,
+ 'stdout': stdout,
+ 'stdout_lines': stdout.splitlines(),
+ 'stderr': err}
+ res = action_base._parse_returned_data(returned_data)
+ self.assertFalse(res['_ansible_parsed'])
+ self.assertTrue(res['failed'])
+ self.assertEqual(res['module_stderr'], err)
+
+ def test_json_empty(self):
+ action_base = _action_base()
+ rc = 0
+ stdout = '{}\n'
+ err = ''
+ returned_data = {'rc': rc,
+ 'stdout': stdout,
+ 'stdout_lines': stdout.splitlines(),
+ 'stderr': err}
+ res = action_base._parse_returned_data(returned_data)
+ del res['_ansible_parsed'] # we always have _ansible_parsed
+ self.assertEqual(len(res), 0)
+ self.assertFalse(res)
+
+ def test_json_facts(self):
+ action_base = _action_base()
+ rc = 0
+ stdout = '{"ansible_facts": {"foo": "bar", "ansible_blip": "blip_value"}}\n'
+ err = ''
+
+ returned_data = {'rc': rc,
+ 'stdout': stdout,
+ 'stdout_lines': stdout.splitlines(),
+ 'stderr': err}
+ res = action_base._parse_returned_data(returned_data)
+ self.assertTrue(res['ansible_facts'])
+ self.assertIn('ansible_blip', res['ansible_facts'])
+ # TODO: Should this be an AnsibleUnsafe?
+ # self.assertIsInstance(res['ansible_facts'], AnsibleUnsafe)
+
+ def test_json_facts_add_host(self):
+ action_base = _action_base()
+ rc = 0
+ stdout = '''{"ansible_facts": {"foo": "bar", "ansible_blip": "blip_value"},
+ "add_host": {"host_vars": {"some_key": ["whatever the add_host object is"]}
+ }
+ }\n'''
+ err = ''
+
+ returned_data = {'rc': rc,
+ 'stdout': stdout,
+ 'stdout_lines': stdout.splitlines(),
+ 'stderr': err}
+ res = action_base._parse_returned_data(returned_data)
+ self.assertTrue(res['ansible_facts'])
+ self.assertIn('ansible_blip', res['ansible_facts'])
+ self.assertIn('add_host', res)
+ # TODO: Should this be an AnsibleUnsafe?
+ # self.assertIsInstance(res['ansible_facts'], AnsibleUnsafe)
diff --git a/test/units/plugins/action/test_gather_facts.py b/test/units/plugins/action/test_gather_facts.py
new file mode 100644
index 0000000..20225aa
--- /dev/null
+++ b/test/units/plugins/action/test_gather_facts.py
@@ -0,0 +1,98 @@
+# (c) 2016, Saran Ahluwalia <ahlusar.ahluwalia@gmail.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from units.compat import unittest
+from unittest.mock import MagicMock, patch
+
+from ansible import constants as C
+from ansible.playbook.task import Task
+from ansible.plugins.action.gather_facts import ActionModule as GatherFactsAction
+from ansible.template import Templar
+from ansible.executor import module_common
+
+from units.mock.loader import DictDataLoader
+
+
+class TestNetworkFacts(unittest.TestCase):
+ task = MagicMock(Task)
+ play_context = MagicMock()
+ play_context.check_mode = False
+ connection = MagicMock()
+ fake_loader = DictDataLoader({
+ })
+ templar = Templar(loader=fake_loader)
+
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ @patch.object(module_common, '_get_collection_metadata', return_value={})
+ def test_network_gather_facts_smart_facts_module(self, mock_collection_metadata):
+ self.fqcn_task_vars = {'ansible_network_os': 'ios'}
+ self.task.action = 'gather_facts'
+ self.task.async_val = False
+ self.task.args = {}
+
+ plugin = GatherFactsAction(self.task, self.connection, self.play_context, loader=None, templar=self.templar, shared_loader_obj=None)
+ get_module_args = MagicMock()
+ plugin._get_module_args = get_module_args
+ plugin._execute_module = MagicMock()
+
+ res = plugin.run(task_vars=self.fqcn_task_vars)
+
+ # assert the gather_facts config is 'smart'
+ facts_modules = C.config.get_config_value('FACTS_MODULES', variables=self.fqcn_task_vars)
+ self.assertEqual(facts_modules, ['smart'])
+
+ # assert the correct module was found
+ self.assertEqual(get_module_args.call_count, 1)
+
+ self.assertEqual(
+ get_module_args.call_args.args,
+ ('ansible.legacy.ios_facts', {'ansible_network_os': 'ios'},)
+ )
+
+ @patch.object(module_common, '_get_collection_metadata', return_value={})
+ def test_network_gather_facts_smart_facts_module_fqcn(self, mock_collection_metadata):
+ self.fqcn_task_vars = {'ansible_network_os': 'cisco.ios.ios'}
+ self.task.action = 'gather_facts'
+ self.task.async_val = False
+ self.task.args = {}
+
+ plugin = GatherFactsAction(self.task, self.connection, self.play_context, loader=None, templar=self.templar, shared_loader_obj=None)
+ get_module_args = MagicMock()
+ plugin._get_module_args = get_module_args
+ plugin._execute_module = MagicMock()
+
+ res = plugin.run(task_vars=self.fqcn_task_vars)
+
+ # assert the gather_facts config is 'smart'
+ facts_modules = C.config.get_config_value('FACTS_MODULES', variables=self.fqcn_task_vars)
+ self.assertEqual(facts_modules, ['smart'])
+
+ # assert the correct module was found
+ self.assertEqual(get_module_args.call_count, 1)
+
+ self.assertEqual(
+ get_module_args.call_args.args,
+ ('cisco.ios.ios_facts', {'ansible_network_os': 'cisco.ios.ios'},)
+ )
diff --git a/test/units/plugins/action/test_pause.py b/test/units/plugins/action/test_pause.py
new file mode 100644
index 0000000..8ad6db7
--- /dev/null
+++ b/test/units/plugins/action/test_pause.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2021 Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+import curses
+import importlib
+import io
+import pytest
+import sys
+
+from ansible.plugins.action import pause # noqa: F401
+from ansible.module_utils.six import PY2
+
+builtin_import = 'builtins.__import__'
+if PY2:
+ builtin_import = '__builtin__.__import__'
+
+
+def test_pause_curses_tigetstr_none(mocker, monkeypatch):
+ monkeypatch.delitem(sys.modules, 'ansible.plugins.action.pause')
+
+ dunder_import = __import__
+
+ def _import(*args, **kwargs):
+ if args[0] == 'curses':
+ mock_curses = mocker.Mock()
+ mock_curses.setupterm = mocker.Mock(return_value=True)
+ mock_curses.tigetstr = mocker.Mock(return_value=None)
+ return mock_curses
+ else:
+ return dunder_import(*args, **kwargs)
+
+ mocker.patch(builtin_import, _import)
+
+ mod = importlib.import_module('ansible.plugins.action.pause')
+
+ assert mod.HAS_CURSES is True
+ assert mod.MOVE_TO_BOL == b'\r'
+ assert mod.CLEAR_TO_EOL == b'\x1b[K'
+
+
+def test_pause_missing_curses(mocker, monkeypatch):
+ monkeypatch.delitem(sys.modules, 'ansible.plugins.action.pause')
+
+ dunder_import = __import__
+
+ def _import(*args, **kwargs):
+ if args[0] == 'curses':
+ raise ImportError
+ else:
+ return dunder_import(*args, **kwargs)
+
+ mocker.patch(builtin_import, _import)
+
+ mod = importlib.import_module('ansible.plugins.action.pause')
+
+ with pytest.raises(AttributeError):
+ mod.curses
+
+ assert mod.HAS_CURSES is False
+ assert mod.MOVE_TO_BOL == b'\r'
+ assert mod.CLEAR_TO_EOL == b'\x1b[K'
+
+
+@pytest.mark.parametrize('exc', (curses.error, TypeError, io.UnsupportedOperation))
+def test_pause_curses_setupterm_error(mocker, monkeypatch, exc):
+ monkeypatch.delitem(sys.modules, 'ansible.plugins.action.pause')
+
+ dunder_import = __import__
+
+ def _import(*args, **kwargs):
+ if args[0] == 'curses':
+ mock_curses = mocker.Mock()
+ mock_curses.setupterm = mocker.Mock(side_effect=exc)
+ mock_curses.error = curses.error
+ return mock_curses
+ else:
+ return dunder_import(*args, **kwargs)
+
+ mocker.patch(builtin_import, _import)
+
+ mod = importlib.import_module('ansible.plugins.action.pause')
+
+ assert mod.HAS_CURSES is False
+ assert mod.MOVE_TO_BOL == b'\r'
+ assert mod.CLEAR_TO_EOL == b'\x1b[K'
diff --git a/test/units/plugins/action/test_raw.py b/test/units/plugins/action/test_raw.py
new file mode 100644
index 0000000..3348051
--- /dev/null
+++ b/test/units/plugins/action/test_raw.py
@@ -0,0 +1,105 @@
+# (c) 2016, Saran Ahluwalia <ahlusar.ahluwalia@gmail.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import os
+
+from ansible.errors import AnsibleActionFail
+from units.compat import unittest
+from unittest.mock import MagicMock, Mock
+from ansible.plugins.action.raw import ActionModule
+from ansible.playbook.task import Task
+from ansible.plugins.loader import connection_loader
+
+
+class TestCopyResultExclude(unittest.TestCase):
+
+ def setUp(self):
+ self.play_context = Mock()
+ self.play_context.shell = 'sh'
+ self.connection = connection_loader.get('local', self.play_context, os.devnull)
+
+ def tearDown(self):
+ pass
+
+ # The current behavior of the raw aciton in regards to executable is currently in question;
+ # the test_raw_executable_is_not_empty_string verifies the current behavior (whether it is desireed or not.
+ # Please refer to the following for context:
+ # Issue: https://github.com/ansible/ansible/issues/16054
+ # PR: https://github.com/ansible/ansible/pull/16085
+
+ def test_raw_executable_is_not_empty_string(self):
+
+ task = MagicMock(Task)
+ task.async_val = False
+
+ task.args = {'_raw_params': 'Args1'}
+ self.play_context.check_mode = False
+
+ self.mock_am = ActionModule(task, self.connection, self.play_context, loader=None, templar=None, shared_loader_obj=None)
+ self.mock_am._low_level_execute_command = Mock(return_value={})
+ self.mock_am.display = Mock()
+ self.mock_am._admin_users = ['root', 'toor']
+
+ self.mock_am.run()
+ self.mock_am._low_level_execute_command.assert_called_with('Args1', executable=False)
+
+ def test_raw_check_mode_is_True(self):
+
+ task = MagicMock(Task)
+ task.async_val = False
+
+ task.args = {'_raw_params': 'Args1'}
+ self.play_context.check_mode = True
+
+ try:
+ self.mock_am = ActionModule(task, self.connection, self.play_context, loader=None, templar=None, shared_loader_obj=None)
+ except AnsibleActionFail:
+ pass
+
+ def test_raw_test_environment_is_None(self):
+
+ task = MagicMock(Task)
+ task.async_val = False
+
+ task.args = {'_raw_params': 'Args1'}
+ task.environment = None
+ self.play_context.check_mode = False
+
+ self.mock_am = ActionModule(task, self.connection, self.play_context, loader=None, templar=None, shared_loader_obj=None)
+ self.mock_am._low_level_execute_command = Mock(return_value={})
+ self.mock_am.display = Mock()
+
+ self.assertEqual(task.environment, None)
+
+ def test_raw_task_vars_is_not_None(self):
+
+ task = MagicMock(Task)
+ task.async_val = False
+
+ task.args = {'_raw_params': 'Args1'}
+ task.environment = None
+ self.play_context.check_mode = False
+
+ self.mock_am = ActionModule(task, self.connection, self.play_context, loader=None, templar=None, shared_loader_obj=None)
+ self.mock_am._low_level_execute_command = Mock(return_value={})
+ self.mock_am.display = Mock()
+
+ self.mock_am.run(task_vars={'a': 'b'})
+ self.assertEqual(task.environment, None)
diff --git a/test/units/plugins/become/__init__.py b/test/units/plugins/become/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/units/plugins/become/__init__.py
diff --git a/test/units/plugins/become/conftest.py b/test/units/plugins/become/conftest.py
new file mode 100644
index 0000000..a04a5e2
--- /dev/null
+++ b/test/units/plugins/become/conftest.py
@@ -0,0 +1,37 @@
+# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
+# (c) 2017 Ansible Project
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import pytest
+
+from ansible.cli.arguments import option_helpers as opt_help
+from ansible.utils import context_objects as co
+
+
+@pytest.fixture
+def parser():
+ parser = opt_help.create_base_parser('testparser')
+
+ opt_help.add_runas_options(parser)
+ opt_help.add_meta_options(parser)
+ opt_help.add_runtask_options(parser)
+ opt_help.add_vault_options(parser)
+ opt_help.add_async_options(parser)
+ opt_help.add_connect_options(parser)
+ opt_help.add_subset_options(parser)
+ opt_help.add_check_options(parser)
+ opt_help.add_inventory_options(parser)
+
+ return parser
+
+
+@pytest.fixture
+def reset_cli_args():
+ co.GlobalCLIArgs._Singleton__instance = None
+ yield
+ co.GlobalCLIArgs._Singleton__instance = None
diff --git a/test/units/plugins/become/test_su.py b/test/units/plugins/become/test_su.py
new file mode 100644
index 0000000..bf74a4c
--- /dev/null
+++ b/test/units/plugins/become/test_su.py
@@ -0,0 +1,30 @@
+# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
+# (c) 2020 Ansible Project
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import re
+
+from ansible import context
+from ansible.plugins.loader import become_loader, shell_loader
+
+
+def test_su(mocker, parser, reset_cli_args):
+ options = parser.parse_args([])
+ context._init_global_context(options)
+
+ su = become_loader.get('su')
+ sh = shell_loader.get('sh')
+ sh.executable = "/bin/bash"
+
+ su.set_options(direct={
+ 'become_user': 'foo',
+ 'become_flags': '',
+ })
+
+ cmd = su.build_become_command('/bin/foo', sh)
+ assert re.match(r"""su\s+foo -c '/bin/bash -c '"'"'echo BECOME-SUCCESS-.+?; /bin/foo'"'"''""", cmd)
diff --git a/test/units/plugins/become/test_sudo.py b/test/units/plugins/become/test_sudo.py
new file mode 100644
index 0000000..67eb9a4
--- /dev/null
+++ b/test/units/plugins/become/test_sudo.py
@@ -0,0 +1,67 @@
+# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
+# (c) 2020 Ansible Project
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import re
+
+from ansible import context
+from ansible.plugins.loader import become_loader, shell_loader
+
+
+def test_sudo(mocker, parser, reset_cli_args):
+ options = parser.parse_args([])
+ context._init_global_context(options)
+
+ sudo = become_loader.get('sudo')
+ sh = shell_loader.get('sh')
+ sh.executable = "/bin/bash"
+
+ sudo.set_options(direct={
+ 'become_user': 'foo',
+ 'become_flags': '-n -s -H',
+ })
+
+ cmd = sudo.build_become_command('/bin/foo', sh)
+
+ assert re.match(r"""sudo\s+-n -s -H\s+-u foo /bin/bash -c 'echo BECOME-SUCCESS-.+? ; /bin/foo'""", cmd), cmd
+
+ sudo.set_options(direct={
+ 'become_user': 'foo',
+ 'become_flags': '-n -s -H',
+ 'become_pass': 'testpass',
+ })
+
+ cmd = sudo.build_become_command('/bin/foo', sh)
+ assert re.match(r"""sudo\s+-s\s-H\s+-p "\[sudo via ansible, key=.+?\] password:" -u foo /bin/bash -c 'echo BECOME-SUCCESS-.+? ; /bin/foo'""", cmd), cmd
+
+ sudo.set_options(direct={
+ 'become_user': 'foo',
+ 'become_flags': '-snH',
+ 'become_pass': 'testpass',
+ })
+
+ cmd = sudo.build_become_command('/bin/foo', sh)
+ assert re.match(r"""sudo\s+-sH\s+-p "\[sudo via ansible, key=.+?\] password:" -u foo /bin/bash -c 'echo BECOME-SUCCESS-.+? ; /bin/foo'""", cmd), cmd
+
+ sudo.set_options(direct={
+ 'become_user': 'foo',
+ 'become_flags': '--non-interactive -s -H',
+ 'become_pass': 'testpass',
+ })
+
+ cmd = sudo.build_become_command('/bin/foo', sh)
+ assert re.match(r"""sudo\s+-s\s-H\s+-p "\[sudo via ansible, key=.+?\] password:" -u foo /bin/bash -c 'echo BECOME-SUCCESS-.+? ; /bin/foo'""", cmd), cmd
+
+ sudo.set_options(direct={
+ 'become_user': 'foo',
+ 'become_flags': '--non-interactive -nC5 -s -H',
+ 'become_pass': 'testpass',
+ })
+
+ cmd = sudo.build_become_command('/bin/foo', sh)
+ assert re.match(r"""sudo\s+-C5\s-s\s-H\s+-p "\[sudo via ansible, key=.+?\] password:" -u foo /bin/bash -c 'echo BECOME-SUCCESS-.+? ; /bin/foo'""", cmd), cmd
diff --git a/test/units/plugins/cache/__init__.py b/test/units/plugins/cache/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/units/plugins/cache/__init__.py
diff --git a/test/units/plugins/cache/test_cache.py b/test/units/plugins/cache/test_cache.py
new file mode 100644
index 0000000..25b84c0
--- /dev/null
+++ b/test/units/plugins/cache/test_cache.py
@@ -0,0 +1,199 @@
+# (c) 2012-2015, Michael DeHaan <michael.dehaan@gmail.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import os
+import shutil
+import tempfile
+
+from unittest import mock
+
+from units.compat import unittest
+from ansible.errors import AnsibleError
+from ansible.plugins.cache import CachePluginAdjudicator
+from ansible.plugins.cache.memory import CacheModule as MemoryCache
+from ansible.plugins.loader import cache_loader
+from ansible.vars.fact_cache import FactCache
+
+import pytest
+
+
+class TestCachePluginAdjudicator(unittest.TestCase):
+ def setUp(self):
+ # memory plugin cache
+ self.cache = CachePluginAdjudicator()
+ self.cache['cache_key'] = {'key1': 'value1', 'key2': 'value2'}
+ self.cache['cache_key_2'] = {'key': 'value'}
+
+ def test___setitem__(self):
+ self.cache['new_cache_key'] = {'new_key1': ['new_value1', 'new_value2']}
+ assert self.cache['new_cache_key'] == {'new_key1': ['new_value1', 'new_value2']}
+
+ def test_inner___setitem__(self):
+ self.cache['new_cache_key'] = {'new_key1': ['new_value1', 'new_value2']}
+ self.cache['new_cache_key']['new_key1'][0] = 'updated_value1'
+ assert self.cache['new_cache_key'] == {'new_key1': ['updated_value1', 'new_value2']}
+
+ def test___contains__(self):
+ assert 'cache_key' in self.cache
+ assert 'not_cache_key' not in self.cache
+
+ def test_get(self):
+ assert self.cache.get('cache_key') == {'key1': 'value1', 'key2': 'value2'}
+
+ def test_get_with_default(self):
+ assert self.cache.get('foo', 'bar') == 'bar'
+
+ def test_get_without_default(self):
+ assert self.cache.get('foo') is None
+
+ def test___getitem__(self):
+ with pytest.raises(KeyError):
+ self.cache['foo']
+
+ def test_pop_with_default(self):
+ assert self.cache.pop('foo', 'bar') == 'bar'
+
+ def test_pop_without_default(self):
+ with pytest.raises(KeyError):
+ assert self.cache.pop('foo')
+
+ def test_pop(self):
+ v = self.cache.pop('cache_key_2')
+ assert v == {'key': 'value'}
+ assert 'cache_key_2' not in self.cache
+
+ def test_update(self):
+ self.cache.update({'cache_key': {'key2': 'updatedvalue'}})
+ assert self.cache['cache_key']['key2'] == 'updatedvalue'
+
+ def test_update_cache_if_changed(self):
+ # Changes are stored in the CachePluginAdjudicator and will be
+ # persisted to the plugin when calling update_cache_if_changed()
+ # The exception is flush which flushes the plugin immediately.
+ assert len(self.cache.keys()) == 2
+ assert len(self.cache._plugin.keys()) == 0
+ self.cache.update_cache_if_changed()
+ assert len(self.cache._plugin.keys()) == 2
+
+ def test_flush(self):
+ # Fake that the cache already has some data in it but the adjudicator
+ # hasn't loaded it in.
+ self.cache._plugin.set('monkey', 'animal')
+ self.cache._plugin.set('wolf', 'animal')
+ self.cache._plugin.set('another wolf', 'another animal')
+
+ # The adjudicator does't know about the new entries
+ assert len(self.cache.keys()) == 2
+ # But the cache itself does
+ assert len(self.cache._plugin.keys()) == 3
+
+ # If we call flush, both the adjudicator and the cache should flush
+ self.cache.flush()
+ assert len(self.cache.keys()) == 0
+ assert len(self.cache._plugin.keys()) == 0
+
+
+class TestJsonFileCache(TestCachePluginAdjudicator):
+ cache_prefix = ''
+
+ def setUp(self):
+ self.cache_dir = tempfile.mkdtemp(prefix='ansible-plugins-cache-')
+ self.cache = CachePluginAdjudicator(
+ plugin_name='jsonfile', _uri=self.cache_dir,
+ _prefix=self.cache_prefix)
+ self.cache['cache_key'] = {'key1': 'value1', 'key2': 'value2'}
+ self.cache['cache_key_2'] = {'key': 'value'}
+
+ def test_keys(self):
+ # A cache without a prefix will consider all files in the cache
+ # directory as valid cache entries.
+ self.cache._plugin._dump(
+ 'no prefix', os.path.join(self.cache_dir, 'no_prefix'))
+ self.cache._plugin._dump(
+ 'special cache', os.path.join(self.cache_dir, 'special_test'))
+
+ # The plugin does not know the CachePluginAdjudicator entries.
+ assert sorted(self.cache._plugin.keys()) == [
+ 'no_prefix', 'special_test']
+
+ assert 'no_prefix' in self.cache
+ assert 'special_test' in self.cache
+ assert 'test' not in self.cache
+ assert self.cache['no_prefix'] == 'no prefix'
+ assert self.cache['special_test'] == 'special cache'
+
+ def tearDown(self):
+ shutil.rmtree(self.cache_dir)
+
+
+class TestJsonFileCachePrefix(TestJsonFileCache):
+ cache_prefix = 'special_'
+
+ def test_keys(self):
+ # For caches with a prefix only files that match the prefix are
+ # considered. The prefix is removed from the key name.
+ self.cache._plugin._dump(
+ 'no prefix', os.path.join(self.cache_dir, 'no_prefix'))
+ self.cache._plugin._dump(
+ 'special cache', os.path.join(self.cache_dir, 'special_test'))
+
+ # The plugin does not know the CachePluginAdjudicator entries.
+ assert sorted(self.cache._plugin.keys()) == ['test']
+
+ assert 'no_prefix' not in self.cache
+ assert 'special_test' not in self.cache
+ assert 'test' in self.cache
+ assert self.cache['test'] == 'special cache'
+
+
+class TestFactCache(unittest.TestCase):
+ def setUp(self):
+ with mock.patch('ansible.constants.CACHE_PLUGIN', 'memory'):
+ self.cache = FactCache()
+
+ def test_copy(self):
+ self.cache['avocado'] = 'fruit'
+ self.cache['daisy'] = 'flower'
+ a_copy = self.cache.copy()
+ self.assertEqual(type(a_copy), dict)
+ self.assertEqual(a_copy, dict(avocado='fruit', daisy='flower'))
+
+ def test_flush(self):
+ self.cache['motorcycle'] = 'vehicle'
+ self.cache['sock'] = 'clothing'
+ self.cache.flush()
+ assert len(self.cache.keys()) == 0
+
+ def test_plugin_load_failure(self):
+ # See https://github.com/ansible/ansible/issues/18751
+ # Note no fact_connection config set, so this will fail
+ with mock.patch('ansible.constants.CACHE_PLUGIN', 'json'):
+ self.assertRaisesRegex(AnsibleError,
+ "Unable to load the facts cache plugin.*json.*",
+ FactCache)
+
+ def test_update(self):
+ self.cache.update({'cache_key': {'key2': 'updatedvalue'}})
+ assert self.cache['cache_key']['key2'] == 'updatedvalue'
+
+
+def test_memory_cachemodule_with_loader():
+ assert isinstance(cache_loader.get('memory'), MemoryCache)
diff --git a/test/units/plugins/callback/__init__.py b/test/units/plugins/callback/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/units/plugins/callback/__init__.py
diff --git a/test/units/plugins/callback/test_callback.py b/test/units/plugins/callback/test_callback.py
new file mode 100644
index 0000000..ccfa465
--- /dev/null
+++ b/test/units/plugins/callback/test_callback.py
@@ -0,0 +1,416 @@
+# (c) 2012-2014, Chris Meyers <chris.meyers.fsu@gmail.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import json
+import re
+import textwrap
+import types
+
+from units.compat import unittest
+from unittest.mock import MagicMock
+
+from ansible.executor.task_result import TaskResult
+from ansible.inventory.host import Host
+from ansible.plugins.callback import CallbackBase
+
+
+mock_task = MagicMock()
+mock_task.delegate_to = None
+
+
+class TestCallback(unittest.TestCase):
+ # FIXME: This doesn't really test anything...
+ def test_init(self):
+ CallbackBase()
+
+ def test_display(self):
+ display_mock = MagicMock()
+ display_mock.verbosity = 0
+ cb = CallbackBase(display=display_mock)
+ self.assertIs(cb._display, display_mock)
+
+ def test_display_verbose(self):
+ display_mock = MagicMock()
+ display_mock.verbosity = 5
+ cb = CallbackBase(display=display_mock)
+ self.assertIs(cb._display, display_mock)
+
+ def test_host_label(self):
+ result = TaskResult(host=Host('host1'), task=mock_task, return_data={})
+
+ self.assertEqual(CallbackBase.host_label(result), 'host1')
+
+ def test_host_label_delegated(self):
+ mock_task.delegate_to = 'host2'
+ result = TaskResult(
+ host=Host('host1'),
+ task=mock_task,
+ return_data={'_ansible_delegated_vars': {'ansible_host': 'host2'}},
+ )
+ self.assertEqual(CallbackBase.host_label(result), 'host1 -> host2')
+
+ # TODO: import callback module so we can patch callback.cli/callback.C
+
+
+class TestCallbackResults(unittest.TestCase):
+
+ def test_get_item_label(self):
+ cb = CallbackBase()
+ results = {'item': 'some_item'}
+ res = cb._get_item_label(results)
+ self.assertEqual(res, 'some_item')
+
+ def test_get_item_label_no_log(self):
+ cb = CallbackBase()
+ results = {'item': 'some_item', '_ansible_no_log': True}
+ res = cb._get_item_label(results)
+ self.assertEqual(res, "(censored due to no_log)")
+
+ results = {'item': 'some_item', '_ansible_no_log': False}
+ res = cb._get_item_label(results)
+ self.assertEqual(res, "some_item")
+
+ def test_clean_results_debug_task(self):
+ cb = CallbackBase()
+ result = {'item': 'some_item',
+ 'invocation': 'foo --bar whatever [some_json]',
+ 'a': 'a single a in result note letter a is in invocation',
+ 'b': 'a single b in result note letter b is not in invocation',
+ 'changed': True}
+
+ cb._clean_results(result, 'debug')
+
+ # See https://github.com/ansible/ansible/issues/33723
+ self.assertTrue('a' in result)
+ self.assertTrue('b' in result)
+ self.assertFalse('invocation' in result)
+ self.assertFalse('changed' in result)
+
+ def test_clean_results_debug_task_no_invocation(self):
+ cb = CallbackBase()
+ result = {'item': 'some_item',
+ 'a': 'a single a in result note letter a is in invocation',
+ 'b': 'a single b in result note letter b is not in invocation',
+ 'changed': True}
+
+ cb._clean_results(result, 'debug')
+ self.assertTrue('a' in result)
+ self.assertTrue('b' in result)
+ self.assertFalse('changed' in result)
+ self.assertFalse('invocation' in result)
+
+ def test_clean_results_debug_task_empty_results(self):
+ cb = CallbackBase()
+ result = {}
+ cb._clean_results(result, 'debug')
+ self.assertFalse('invocation' in result)
+ self.assertEqual(len(result), 0)
+
+ def test_clean_results(self):
+ cb = CallbackBase()
+ result = {'item': 'some_item',
+ 'invocation': 'foo --bar whatever [some_json]',
+ 'a': 'a single a in result note letter a is in invocation',
+ 'b': 'a single b in result note letter b is not in invocation',
+ 'changed': True}
+
+ expected_result = result.copy()
+ cb._clean_results(result, 'ebug')
+ self.assertEqual(result, expected_result)
+
+
+class TestCallbackDumpResults(object):
+ def test_internal_keys(self):
+ cb = CallbackBase()
+ result = {'item': 'some_item',
+ '_ansible_some_var': 'SENTINEL',
+ 'testing_ansible_out': 'should_be_left_in LEFTIN',
+ 'invocation': 'foo --bar whatever [some_json]',
+ 'some_dict_key': {'a_sub_dict_for_key': 'baz'},
+ 'bad_dict_key': {'_ansible_internal_blah': 'SENTINEL'},
+ 'changed': True}
+ json_out = cb._dump_results(result)
+ assert '"_ansible_' not in json_out
+ assert 'SENTINEL' not in json_out
+ assert 'LEFTIN' in json_out
+
+ def test_exception(self):
+ cb = CallbackBase()
+ result = {'item': 'some_item LEFTIN',
+ 'exception': ['frame1', 'SENTINEL']}
+ json_out = cb._dump_results(result)
+ assert 'SENTINEL' not in json_out
+ assert 'exception' not in json_out
+ assert 'LEFTIN' in json_out
+
+ def test_verbose(self):
+ cb = CallbackBase()
+ result = {'item': 'some_item LEFTIN',
+ '_ansible_verbose_always': 'chicane'}
+ json_out = cb._dump_results(result)
+ assert 'SENTINEL' not in json_out
+ assert 'LEFTIN' in json_out
+
+ def test_diff(self):
+ cb = CallbackBase()
+ result = {'item': 'some_item LEFTIN',
+ 'diff': ['remove stuff', 'added LEFTIN'],
+ '_ansible_verbose_always': 'chicane'}
+ json_out = cb._dump_results(result)
+ assert 'SENTINEL' not in json_out
+ assert 'LEFTIN' in json_out
+
+ def test_mixed_keys(self):
+ cb = CallbackBase()
+ result = {3: 'pi',
+ 'tau': 6}
+ json_out = cb._dump_results(result)
+ round_trip_result = json.loads(json_out)
+ assert len(round_trip_result) == 2
+ assert '3' in round_trip_result
+ assert 'tau' in round_trip_result
+ assert round_trip_result['3'] == 'pi'
+ assert round_trip_result['tau'] == 6
+
+
+class TestCallbackDiff(unittest.TestCase):
+
+ def setUp(self):
+ self.cb = CallbackBase()
+
+ def _strip_color(self, s):
+ return re.sub('\033\\[[^m]*m', '', s)
+
+ def test_difflist(self):
+ # TODO: split into smaller tests?
+ difflist = [{'before': u'preface\nThe Before String\npostscript',
+ 'after': u'preface\nThe After String\npostscript',
+ 'before_header': u'just before',
+ 'after_header': u'just after'
+ },
+ {'before': u'preface\nThe Before String\npostscript',
+ 'after': u'preface\nThe After String\npostscript',
+ },
+ {'src_binary': 'chicane'},
+ {'dst_binary': 'chicanery'},
+ {'dst_larger': 1},
+ {'src_larger': 2},
+ {'prepared': u'what does prepared do?'},
+ {'before_header': u'just before'},
+ {'after_header': u'just after'}]
+
+ res = self.cb._get_diff(difflist)
+
+ self.assertIn(u'Before String', res)
+ self.assertIn(u'After String', res)
+ self.assertIn(u'just before', res)
+ self.assertIn(u'just after', res)
+
+ def test_simple_diff(self):
+ self.assertMultiLineEqual(
+ self._strip_color(self.cb._get_diff({
+ 'before_header': 'somefile.txt',
+ 'after_header': 'generated from template somefile.j2',
+ 'before': 'one\ntwo\nthree\n',
+ 'after': 'one\nthree\nfour\n',
+ })),
+ textwrap.dedent('''\
+ --- before: somefile.txt
+ +++ after: generated from template somefile.j2
+ @@ -1,3 +1,3 @@
+ one
+ -two
+ three
+ +four
+
+ '''))
+
+ def test_new_file(self):
+ self.assertMultiLineEqual(
+ self._strip_color(self.cb._get_diff({
+ 'before_header': 'somefile.txt',
+ 'after_header': 'generated from template somefile.j2',
+ 'before': '',
+ 'after': 'one\ntwo\nthree\n',
+ })),
+ textwrap.dedent('''\
+ --- before: somefile.txt
+ +++ after: generated from template somefile.j2
+ @@ -0,0 +1,3 @@
+ +one
+ +two
+ +three
+
+ '''))
+
+ def test_clear_file(self):
+ self.assertMultiLineEqual(
+ self._strip_color(self.cb._get_diff({
+ 'before_header': 'somefile.txt',
+ 'after_header': 'generated from template somefile.j2',
+ 'before': 'one\ntwo\nthree\n',
+ 'after': '',
+ })),
+ textwrap.dedent('''\
+ --- before: somefile.txt
+ +++ after: generated from template somefile.j2
+ @@ -1,3 +0,0 @@
+ -one
+ -two
+ -three
+
+ '''))
+
+ def test_no_trailing_newline_before(self):
+ self.assertMultiLineEqual(
+ self._strip_color(self.cb._get_diff({
+ 'before_header': 'somefile.txt',
+ 'after_header': 'generated from template somefile.j2',
+ 'before': 'one\ntwo\nthree',
+ 'after': 'one\ntwo\nthree\n',
+ })),
+ textwrap.dedent('''\
+ --- before: somefile.txt
+ +++ after: generated from template somefile.j2
+ @@ -1,3 +1,3 @@
+ one
+ two
+ -three
+ \\ No newline at end of file
+ +three
+
+ '''))
+
+ def test_no_trailing_newline_after(self):
+ self.assertMultiLineEqual(
+ self._strip_color(self.cb._get_diff({
+ 'before_header': 'somefile.txt',
+ 'after_header': 'generated from template somefile.j2',
+ 'before': 'one\ntwo\nthree\n',
+ 'after': 'one\ntwo\nthree',
+ })),
+ textwrap.dedent('''\
+ --- before: somefile.txt
+ +++ after: generated from template somefile.j2
+ @@ -1,3 +1,3 @@
+ one
+ two
+ -three
+ +three
+ \\ No newline at end of file
+
+ '''))
+
+ def test_no_trailing_newline_both(self):
+ self.assertMultiLineEqual(
+ self.cb._get_diff({
+ 'before_header': 'somefile.txt',
+ 'after_header': 'generated from template somefile.j2',
+ 'before': 'one\ntwo\nthree',
+ 'after': 'one\ntwo\nthree',
+ }),
+ '')
+
+ def test_no_trailing_newline_both_with_some_changes(self):
+ self.assertMultiLineEqual(
+ self._strip_color(self.cb._get_diff({
+ 'before_header': 'somefile.txt',
+ 'after_header': 'generated from template somefile.j2',
+ 'before': 'one\ntwo\nthree',
+ 'after': 'one\nfive\nthree',
+ })),
+ textwrap.dedent('''\
+ --- before: somefile.txt
+ +++ after: generated from template somefile.j2
+ @@ -1,3 +1,3 @@
+ one
+ -two
+ +five
+ three
+ \\ No newline at end of file
+
+ '''))
+
+ def test_diff_dicts(self):
+ self.assertMultiLineEqual(
+ self._strip_color(self.cb._get_diff({
+ 'before': dict(one=1, two=2, three=3),
+ 'after': dict(one=1, three=3, four=4),
+ })),
+ textwrap.dedent('''\
+ --- before
+ +++ after
+ @@ -1,5 +1,5 @@
+ {
+ + "four": 4,
+ "one": 1,
+ - "three": 3,
+ - "two": 2
+ + "three": 3
+ }
+
+ '''))
+
+ def test_diff_before_none(self):
+ self.assertMultiLineEqual(
+ self._strip_color(self.cb._get_diff({
+ 'before': None,
+ 'after': 'one line\n',
+ })),
+ textwrap.dedent('''\
+ --- before
+ +++ after
+ @@ -0,0 +1 @@
+ +one line
+
+ '''))
+
+ def test_diff_after_none(self):
+ self.assertMultiLineEqual(
+ self._strip_color(self.cb._get_diff({
+ 'before': 'one line\n',
+ 'after': None,
+ })),
+ textwrap.dedent('''\
+ --- before
+ +++ after
+ @@ -1 +0,0 @@
+ -one line
+
+ '''))
+
+
+class TestCallbackOnMethods(unittest.TestCase):
+ def _find_on_methods(self, callback):
+ cb_dir = dir(callback)
+ method_names = [x for x in cb_dir if '_on_' in x]
+ methods = [getattr(callback, mn) for mn in method_names]
+ return methods
+
+ def test_are_methods(self):
+ cb = CallbackBase()
+ for method in self._find_on_methods(cb):
+ self.assertIsInstance(method, types.MethodType)
+
+ def test_on_any(self):
+ cb = CallbackBase()
+ cb.v2_on_any('whatever', some_keyword='blippy')
+ cb.on_any('whatever', some_keyword='blippy')
diff --git a/test/units/plugins/connection/__init__.py b/test/units/plugins/connection/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/units/plugins/connection/__init__.py
diff --git a/test/units/plugins/connection/test_connection.py b/test/units/plugins/connection/test_connection.py
new file mode 100644
index 0000000..38d6691
--- /dev/null
+++ b/test/units/plugins/connection/test_connection.py
@@ -0,0 +1,163 @@
+# (c) 2015, Toshio Kuratomi <tkuratomi@ansible.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from io import StringIO
+
+from units.compat import unittest
+from ansible.playbook.play_context import PlayContext
+from ansible.plugins.connection import ConnectionBase
+from ansible.plugins.loader import become_loader
+
+
+class TestConnectionBaseClass(unittest.TestCase):
+
+ def setUp(self):
+ self.play_context = PlayContext()
+ self.play_context.prompt = (
+ '[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: '
+ )
+ self.in_stream = StringIO()
+
+ def tearDown(self):
+ pass
+
+ def test_subclass_error(self):
+ class ConnectionModule1(ConnectionBase):
+ pass
+ with self.assertRaises(TypeError):
+ ConnectionModule1() # pylint: disable=abstract-class-instantiated
+
+ class ConnectionModule2(ConnectionBase):
+ def get(self, key):
+ super(ConnectionModule2, self).get(key)
+
+ with self.assertRaises(TypeError):
+ ConnectionModule2() # pylint: disable=abstract-class-instantiated
+
+ def test_subclass_success(self):
+ class ConnectionModule3(ConnectionBase):
+
+ @property
+ def transport(self):
+ pass
+
+ def _connect(self):
+ pass
+
+ def exec_command(self):
+ pass
+
+ def put_file(self):
+ pass
+
+ def fetch_file(self):
+ pass
+
+ def close(self):
+ pass
+
+ self.assertIsInstance(ConnectionModule3(self.play_context, self.in_stream), ConnectionModule3)
+
+ def test_check_password_prompt(self):
+ local = (
+ b'[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: \n'
+ b'BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq\n'
+ )
+
+ ssh_pipelining_vvvv = b'''
+debug3: mux_master_read_cb: channel 1 packet type 0x10000002 len 251
+debug2: process_mux_new_session: channel 1: request tty 0, X 1, agent 1, subsys 0, term "xterm-256color", cmd "/bin/sh -c 'sudo -H -S -p "[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: " -u root /bin/sh -c '"'"'echo BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq; /bin/true'"'"' && sleep 0'", env 0
+debug3: process_mux_new_session: got fds stdin 9, stdout 10, stderr 11
+debug2: client_session2_setup: id 2
+debug1: Sending command: /bin/sh -c 'sudo -H -S -p "[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: " -u root /bin/sh -c '"'"'echo BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq; /bin/true'"'"' && sleep 0'
+debug2: channel 2: request exec confirm 1
+debug2: channel 2: rcvd ext data 67
+[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: debug2: channel 2: written 67 to efd 11
+BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq
+debug3: receive packet: type 98
+''' # noqa
+
+ ssh_nopipelining_vvvv = b'''
+debug3: mux_master_read_cb: channel 1 packet type 0x10000002 len 251
+debug2: process_mux_new_session: channel 1: request tty 1, X 1, agent 1, subsys 0, term "xterm-256color", cmd "/bin/sh -c 'sudo -H -S -p "[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: " -u root /bin/sh -c '"'"'echo BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq; /bin/true'"'"' && sleep 0'", env 0
+debug3: mux_client_request_session: session request sent
+debug3: send packet: type 98
+debug1: Sending command: /bin/sh -c 'sudo -H -S -p "[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: " -u root /bin/sh -c '"'"'echo BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq; /bin/true'"'"' && sleep 0'
+debug2: channel 2: request exec confirm 1
+debug2: exec request accepted on channel 2
+[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: debug3: receive packet: type 2
+debug3: Received SSH2_MSG_IGNORE
+debug3: Received SSH2_MSG_IGNORE
+
+BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq
+debug3: receive packet: type 98
+''' # noqa
+
+ ssh_novvvv = (
+ b'[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: \n'
+ b'BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq\n'
+ )
+
+ dns_issue = (
+ b'timeout waiting for privilege escalation password prompt:\n'
+ b'sudo: sudo: unable to resolve host tcloud014\n'
+ b'[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: \n'
+ b'BECOME-SUCCESS-ouzmdnewuhucvuaabtjmweasarviygqq\n'
+ )
+
+ nothing = b''
+
+ in_front = b'''
+debug1: Sending command: /bin/sh -c 'sudo -H -S -p "[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: " -u root /bin/sh -c '"'"'echo
+'''
+
+ class ConnectionFoo(ConnectionBase):
+
+ @property
+ def transport(self):
+ pass
+
+ def _connect(self):
+ pass
+
+ def exec_command(self):
+ pass
+
+ def put_file(self):
+ pass
+
+ def fetch_file(self):
+ pass
+
+ def close(self):
+ pass
+
+ c = ConnectionFoo(self.play_context, self.in_stream)
+ c.set_become_plugin(become_loader.get('sudo'))
+ c.become.prompt = '[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: '
+
+ self.assertTrue(c.become.check_password_prompt(local))
+ self.assertTrue(c.become.check_password_prompt(ssh_pipelining_vvvv))
+ self.assertTrue(c.become.check_password_prompt(ssh_nopipelining_vvvv))
+ self.assertTrue(c.become.check_password_prompt(ssh_novvvv))
+ self.assertTrue(c.become.check_password_prompt(dns_issue))
+ self.assertFalse(c.become.check_password_prompt(nothing))
+ self.assertFalse(c.become.check_password_prompt(in_front))
diff --git a/test/units/plugins/connection/test_local.py b/test/units/plugins/connection/test_local.py
new file mode 100644
index 0000000..e552585
--- /dev/null
+++ b/test/units/plugins/connection/test_local.py
@@ -0,0 +1,40 @@
+#
+# (c) 2020 Red Hat Inc.
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from io import StringIO
+import pytest
+
+from units.compat import unittest
+from ansible.plugins.connection import local
+from ansible.playbook.play_context import PlayContext
+
+
+class TestLocalConnectionClass(unittest.TestCase):
+
+ def test_local_connection_module(self):
+ play_context = PlayContext()
+ play_context.prompt = (
+ '[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: '
+ )
+ in_stream = StringIO()
+
+ self.assertIsInstance(local.Connection(play_context, in_stream), local.Connection)
diff --git a/test/units/plugins/connection/test_paramiko.py b/test/units/plugins/connection/test_paramiko.py
new file mode 100644
index 0000000..dcf3177
--- /dev/null
+++ b/test/units/plugins/connection/test_paramiko.py
@@ -0,0 +1,56 @@
+#
+# (c) 2020 Red Hat Inc.
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from io import StringIO
+import pytest
+
+from ansible.plugins.connection import paramiko_ssh
+from ansible.playbook.play_context import PlayContext
+
+
+@pytest.fixture
+def play_context():
+ play_context = PlayContext()
+ play_context.prompt = (
+ '[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: '
+ )
+
+ return play_context
+
+
+@pytest.fixture()
+def in_stream():
+ return StringIO()
+
+
+def test_paramiko_connection_module(play_context, in_stream):
+ assert isinstance(
+ paramiko_ssh.Connection(play_context, in_stream),
+ paramiko_ssh.Connection)
+
+
+def test_paramiko_connect(play_context, in_stream, mocker):
+ mocker.patch.object(paramiko_ssh.Connection, '_connect_uncached')
+ connection = paramiko_ssh.Connection(play_context, in_stream)._connect()
+
+ assert isinstance(connection, paramiko_ssh.Connection)
+ assert connection._connected is True
diff --git a/test/units/plugins/connection/test_psrp.py b/test/units/plugins/connection/test_psrp.py
new file mode 100644
index 0000000..38052e8
--- /dev/null
+++ b/test/units/plugins/connection/test_psrp.py
@@ -0,0 +1,233 @@
+# -*- coding: utf-8 -*-
+# (c) 2018, Jordan Borean <jborean@redhat.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import pytest
+import sys
+
+from io import StringIO
+from unittest.mock import MagicMock
+
+from ansible.playbook.play_context import PlayContext
+from ansible.plugins.loader import connection_loader
+from ansible.utils.display import Display
+
+
+@pytest.fixture(autouse=True)
+def psrp_connection():
+ """Imports the psrp connection plugin with a mocked pypsrp module for testing"""
+
+ # Take a snapshot of sys.modules before we manipulate it
+ orig_modules = sys.modules.copy()
+ try:
+ fake_pypsrp = MagicMock()
+ fake_pypsrp.FEATURES = [
+ 'wsman_locale',
+ 'wsman_read_timeout',
+ 'wsman_reconnections',
+ ]
+
+ fake_wsman = MagicMock()
+ fake_wsman.AUTH_KWARGS = {
+ "certificate": ["certificate_key_pem", "certificate_pem"],
+ "credssp": ["credssp_auth_mechanism", "credssp_disable_tlsv1_2",
+ "credssp_minimum_version"],
+ "negotiate": ["negotiate_delegate", "negotiate_hostname_override",
+ "negotiate_send_cbt", "negotiate_service"],
+ "mock": ["mock_test1", "mock_test2"],
+ }
+
+ sys.modules["pypsrp"] = fake_pypsrp
+ sys.modules["pypsrp.complex_objects"] = MagicMock()
+ sys.modules["pypsrp.exceptions"] = MagicMock()
+ sys.modules["pypsrp.host"] = MagicMock()
+ sys.modules["pypsrp.powershell"] = MagicMock()
+ sys.modules["pypsrp.shell"] = MagicMock()
+ sys.modules["pypsrp.wsman"] = fake_wsman
+ sys.modules["requests.exceptions"] = MagicMock()
+
+ from ansible.plugins.connection import psrp
+
+ # Take a copy of the original import state vars before we set to an ok import
+ orig_has_psrp = psrp.HAS_PYPSRP
+ orig_psrp_imp_err = psrp.PYPSRP_IMP_ERR
+
+ yield psrp
+
+ psrp.HAS_PYPSRP = orig_has_psrp
+ psrp.PYPSRP_IMP_ERR = orig_psrp_imp_err
+ finally:
+ # Restore sys.modules back to our pre-shenanigans
+ sys.modules = orig_modules
+
+
+class TestConnectionPSRP(object):
+
+ OPTIONS_DATA = (
+ # default options
+ (
+ {'_extras': {}},
+ {
+ '_psrp_auth': 'negotiate',
+ '_psrp_cert_validation': True,
+ '_psrp_configuration_name': 'Microsoft.PowerShell',
+ '_psrp_connection_timeout': 30,
+ '_psrp_message_encryption': 'auto',
+ '_psrp_host': 'inventory_hostname',
+ '_psrp_conn_kwargs': {
+ 'server': 'inventory_hostname',
+ 'port': 5986,
+ 'username': None,
+ 'password': None,
+ 'ssl': True,
+ 'path': 'wsman',
+ 'auth': 'negotiate',
+ 'cert_validation': True,
+ 'connection_timeout': 30,
+ 'encryption': 'auto',
+ 'proxy': None,
+ 'no_proxy': False,
+ 'max_envelope_size': 153600,
+ 'operation_timeout': 20,
+ 'certificate_key_pem': None,
+ 'certificate_pem': None,
+ 'credssp_auth_mechanism': 'auto',
+ 'credssp_disable_tlsv1_2': False,
+ 'credssp_minimum_version': 2,
+ 'negotiate_delegate': None,
+ 'negotiate_hostname_override': None,
+ 'negotiate_send_cbt': True,
+ 'negotiate_service': 'WSMAN',
+ 'read_timeout': 30,
+ 'reconnection_backoff': 2.0,
+ 'reconnection_retries': 0,
+ },
+ '_psrp_max_envelope_size': 153600,
+ '_psrp_ignore_proxy': False,
+ '_psrp_operation_timeout': 20,
+ '_psrp_pass': None,
+ '_psrp_path': 'wsman',
+ '_psrp_port': 5986,
+ '_psrp_proxy': None,
+ '_psrp_protocol': 'https',
+ '_psrp_user': None
+ },
+ ),
+ # ssl=False when port defined to 5985
+ (
+ {'_extras': {}, 'ansible_port': '5985'},
+ {
+ '_psrp_port': 5985,
+ '_psrp_protocol': 'http'
+ },
+ ),
+ # ssl=True when port defined to not 5985
+ (
+ {'_extras': {}, 'ansible_port': 1234},
+ {
+ '_psrp_port': 1234,
+ '_psrp_protocol': 'https'
+ },
+ ),
+ # port 5986 when ssl=True
+ (
+ {'_extras': {}, 'ansible_psrp_protocol': 'https'},
+ {
+ '_psrp_port': 5986,
+ '_psrp_protocol': 'https'
+ },
+ ),
+ # port 5985 when ssl=False
+ (
+ {'_extras': {}, 'ansible_psrp_protocol': 'http'},
+ {
+ '_psrp_port': 5985,
+ '_psrp_protocol': 'http'
+ },
+ ),
+ # psrp extras
+ (
+ {'_extras': {'ansible_psrp_mock_test1': True}},
+ {
+ '_psrp_conn_kwargs': {
+ 'server': 'inventory_hostname',
+ 'port': 5986,
+ 'username': None,
+ 'password': None,
+ 'ssl': True,
+ 'path': 'wsman',
+ 'auth': 'negotiate',
+ 'cert_validation': True,
+ 'connection_timeout': 30,
+ 'encryption': 'auto',
+ 'proxy': None,
+ 'no_proxy': False,
+ 'max_envelope_size': 153600,
+ 'operation_timeout': 20,
+ 'certificate_key_pem': None,
+ 'certificate_pem': None,
+ 'credssp_auth_mechanism': 'auto',
+ 'credssp_disable_tlsv1_2': False,
+ 'credssp_minimum_version': 2,
+ 'negotiate_delegate': None,
+ 'negotiate_hostname_override': None,
+ 'negotiate_send_cbt': True,
+ 'negotiate_service': 'WSMAN',
+ 'read_timeout': 30,
+ 'reconnection_backoff': 2.0,
+ 'reconnection_retries': 0,
+ 'mock_test1': True
+ },
+ },
+ ),
+ # cert validation through string repr of bool
+ (
+ {'_extras': {}, 'ansible_psrp_cert_validation': 'ignore'},
+ {
+ '_psrp_cert_validation': False
+ },
+ ),
+ # cert validation path
+ (
+ {'_extras': {}, 'ansible_psrp_cert_trust_path': '/path/cert.pem'},
+ {
+ '_psrp_cert_validation': '/path/cert.pem'
+ },
+ ),
+ )
+
+ # pylint bug: https://github.com/PyCQA/pylint/issues/511
+ # pylint: disable=undefined-variable
+ @pytest.mark.parametrize('options, expected',
+ ((o, e) for o, e in OPTIONS_DATA))
+ def test_set_options(self, options, expected):
+ pc = PlayContext()
+ new_stdin = StringIO()
+
+ conn = connection_loader.get('psrp', pc, new_stdin)
+ conn.set_options(var_options=options)
+ conn._build_kwargs()
+
+ for attr, expected in expected.items():
+ actual = getattr(conn, attr)
+ assert actual == expected, \
+ "psrp attr '%s', actual '%s' != expected '%s'"\
+ % (attr, actual, expected)
+
+ def test_set_invalid_extras_options(self, monkeypatch):
+ pc = PlayContext()
+ new_stdin = StringIO()
+
+ conn = connection_loader.get('psrp', pc, new_stdin)
+ conn.set_options(var_options={'_extras': {'ansible_psrp_mock_test3': True}})
+
+ mock_display = MagicMock()
+ monkeypatch.setattr(Display, "warning", mock_display)
+ conn._build_kwargs()
+
+ assert mock_display.call_args[0][0] == \
+ 'ansible_psrp_mock_test3 is unsupported by the current psrp version installed'
diff --git a/test/units/plugins/connection/test_ssh.py b/test/units/plugins/connection/test_ssh.py
new file mode 100644
index 0000000..662dff9
--- /dev/null
+++ b/test/units/plugins/connection/test_ssh.py
@@ -0,0 +1,696 @@
+# -*- coding: utf-8 -*-
+# (c) 2015, Toshio Kuratomi <tkuratomi@ansible.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from io import StringIO
+import pytest
+
+
+from ansible import constants as C
+from ansible.errors import AnsibleAuthenticationFailure
+from units.compat import unittest
+from unittest.mock import patch, MagicMock, PropertyMock
+from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound
+from ansible.module_utils.compat.selectors import SelectorKey, EVENT_READ
+from ansible.module_utils.six.moves import shlex_quote
+from ansible.module_utils._text import to_bytes
+from ansible.playbook.play_context import PlayContext
+from ansible.plugins.connection import ssh
+from ansible.plugins.loader import connection_loader, become_loader
+
+
+class TestConnectionBaseClass(unittest.TestCase):
+
+ def test_plugins_connection_ssh_module(self):
+ play_context = PlayContext()
+ play_context.prompt = (
+ '[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: '
+ )
+ in_stream = StringIO()
+
+ self.assertIsInstance(ssh.Connection(play_context, in_stream), ssh.Connection)
+
+ def test_plugins_connection_ssh_basic(self):
+ pc = PlayContext()
+ new_stdin = StringIO()
+ conn = ssh.Connection(pc, new_stdin)
+
+ # connect just returns self, so assert that
+ res = conn._connect()
+ self.assertEqual(conn, res)
+
+ ssh.SSHPASS_AVAILABLE = False
+ self.assertFalse(conn._sshpass_available())
+
+ ssh.SSHPASS_AVAILABLE = True
+ self.assertTrue(conn._sshpass_available())
+
+ with patch('subprocess.Popen') as p:
+ ssh.SSHPASS_AVAILABLE = None
+ p.return_value = MagicMock()
+ self.assertTrue(conn._sshpass_available())
+
+ ssh.SSHPASS_AVAILABLE = None
+ p.return_value = None
+ p.side_effect = OSError()
+ self.assertFalse(conn._sshpass_available())
+
+ conn.close()
+ self.assertFalse(conn._connected)
+
+ def test_plugins_connection_ssh__build_command(self):
+ pc = PlayContext()
+ new_stdin = StringIO()
+ conn = connection_loader.get('ssh', pc, new_stdin)
+ conn.get_option = MagicMock()
+ conn.get_option.return_value = ""
+ conn._build_command('ssh', 'ssh')
+
+ def test_plugins_connection_ssh_exec_command(self):
+ pc = PlayContext()
+ new_stdin = StringIO()
+ conn = connection_loader.get('ssh', pc, new_stdin)
+
+ conn._build_command = MagicMock()
+ conn._build_command.return_value = 'ssh something something'
+ conn._run = MagicMock()
+ conn._run.return_value = (0, 'stdout', 'stderr')
+ conn.get_option = MagicMock()
+ conn.get_option.return_value = True
+
+ res, stdout, stderr = conn.exec_command('ssh')
+ res, stdout, stderr = conn.exec_command('ssh', 'this is some data')
+
+ def test_plugins_connection_ssh__examine_output(self):
+ pc = PlayContext()
+ new_stdin = StringIO()
+ become_success_token = b'BECOME-SUCCESS-abcdefghijklmnopqrstuvxyz'
+
+ conn = connection_loader.get('ssh', pc, new_stdin)
+ conn.set_become_plugin(become_loader.get('sudo'))
+
+ conn.become.check_password_prompt = MagicMock()
+ conn.become.check_success = MagicMock()
+ conn.become.check_incorrect_password = MagicMock()
+ conn.become.check_missing_password = MagicMock()
+
+ def _check_password_prompt(line):
+ return b'foo' in line
+
+ def _check_become_success(line):
+ return become_success_token in line
+
+ def _check_incorrect_password(line):
+ return b'incorrect password' in line
+
+ def _check_missing_password(line):
+ return b'bad password' in line
+
+ # test examining output for prompt
+ conn._flags = dict(
+ become_prompt=False,
+ become_success=False,
+ become_error=False,
+ become_nopasswd_error=False,
+ )
+
+ pc.prompt = True
+
+ # override become plugin
+ conn.become.prompt = True
+ conn.become.check_password_prompt = MagicMock(side_effect=_check_password_prompt)
+ conn.become.check_success = MagicMock(side_effect=_check_become_success)
+ conn.become.check_incorrect_password = MagicMock(side_effect=_check_incorrect_password)
+ conn.become.check_missing_password = MagicMock(side_effect=_check_missing_password)
+
+ def get_option(option):
+ if option == 'become_pass':
+ return 'password'
+ return None
+
+ conn.become.get_option = get_option
+ output, unprocessed = conn._examine_output(u'source', u'state', b'line 1\nline 2\nfoo\nline 3\nthis should be the remainder', False)
+ self.assertEqual(output, b'line 1\nline 2\nline 3\n')
+ self.assertEqual(unprocessed, b'this should be the remainder')
+ self.assertTrue(conn._flags['become_prompt'])
+ self.assertFalse(conn._flags['become_success'])
+ self.assertFalse(conn._flags['become_error'])
+ self.assertFalse(conn._flags['become_nopasswd_error'])
+
+ # test examining output for become prompt
+ conn._flags = dict(
+ become_prompt=False,
+ become_success=False,
+ become_error=False,
+ become_nopasswd_error=False,
+ )
+
+ pc.prompt = False
+ conn.become.prompt = False
+ pc.success_key = str(become_success_token)
+ conn.become.success = str(become_success_token)
+ output, unprocessed = conn._examine_output(u'source', u'state', b'line 1\nline 2\n%s\nline 3\n' % become_success_token, False)
+ self.assertEqual(output, b'line 1\nline 2\nline 3\n')
+ self.assertEqual(unprocessed, b'')
+ self.assertFalse(conn._flags['become_prompt'])
+ self.assertTrue(conn._flags['become_success'])
+ self.assertFalse(conn._flags['become_error'])
+ self.assertFalse(conn._flags['become_nopasswd_error'])
+
+ # test we dont detect become success from ssh debug: lines
+ conn._flags = dict(
+ become_prompt=False,
+ become_success=False,
+ become_error=False,
+ become_nopasswd_error=False,
+ )
+
+ pc.prompt = False
+ conn.become.prompt = True
+ pc.success_key = str(become_success_token)
+ conn.become.success = str(become_success_token)
+ output, unprocessed = conn._examine_output(u'source', u'state', b'line 1\nline 2\ndebug1: %s\nline 3\n' % become_success_token, False)
+ self.assertEqual(output, b'line 1\nline 2\ndebug1: %s\nline 3\n' % become_success_token)
+ self.assertEqual(unprocessed, b'')
+ self.assertFalse(conn._flags['become_success'])
+
+ # test examining output for become failure
+ conn._flags = dict(
+ become_prompt=False,
+ become_success=False,
+ become_error=False,
+ become_nopasswd_error=False,
+ )
+
+ pc.prompt = False
+ conn.become.prompt = False
+ pc.success_key = None
+ output, unprocessed = conn._examine_output(u'source', u'state', b'line 1\nline 2\nincorrect password\n', True)
+ self.assertEqual(output, b'line 1\nline 2\nincorrect password\n')
+ self.assertEqual(unprocessed, b'')
+ self.assertFalse(conn._flags['become_prompt'])
+ self.assertFalse(conn._flags['become_success'])
+ self.assertTrue(conn._flags['become_error'])
+ self.assertFalse(conn._flags['become_nopasswd_error'])
+
+ # test examining output for missing password
+ conn._flags = dict(
+ become_prompt=False,
+ become_success=False,
+ become_error=False,
+ become_nopasswd_error=False,
+ )
+
+ pc.prompt = False
+ conn.become.prompt = False
+ pc.success_key = None
+ output, unprocessed = conn._examine_output(u'source', u'state', b'line 1\nbad password\n', True)
+ self.assertEqual(output, b'line 1\nbad password\n')
+ self.assertEqual(unprocessed, b'')
+ self.assertFalse(conn._flags['become_prompt'])
+ self.assertFalse(conn._flags['become_success'])
+ self.assertFalse(conn._flags['become_error'])
+ self.assertTrue(conn._flags['become_nopasswd_error'])
+
+ @patch('time.sleep')
+ @patch('os.path.exists')
+ def test_plugins_connection_ssh_put_file(self, mock_ospe, mock_sleep):
+ pc = PlayContext()
+ new_stdin = StringIO()
+ conn = connection_loader.get('ssh', pc, new_stdin)
+ conn._build_command = MagicMock()
+ conn._bare_run = MagicMock()
+
+ mock_ospe.return_value = True
+ conn._build_command.return_value = 'some command to run'
+ conn._bare_run.return_value = (0, '', '')
+ conn.host = "some_host"
+
+ conn.set_option('reconnection_retries', 9)
+ conn.set_option('ssh_transfer_method', None) # unless set to None scp_if_ssh is ignored
+
+ # Test with SCP_IF_SSH set to smart
+ # Test when SFTP works
+ conn.set_option('scp_if_ssh', 'smart')
+ expected_in_data = b' '.join((b'put', to_bytes(shlex_quote('/path/to/in/file')), to_bytes(shlex_quote('/path/to/dest/file')))) + b'\n'
+ conn.put_file('/path/to/in/file', '/path/to/dest/file')
+ conn._bare_run.assert_called_with('some command to run', expected_in_data, checkrc=False)
+
+ # Test when SFTP doesn't work but SCP does
+ conn._bare_run.side_effect = [(1, 'stdout', 'some errors'), (0, '', '')]
+ conn.put_file('/path/to/in/file', '/path/to/dest/file')
+ conn._bare_run.assert_called_with('some command to run', None, checkrc=False)
+ conn._bare_run.side_effect = None
+
+ # test with SCP_IF_SSH enabled
+ conn.set_option('scp_if_ssh', True)
+ conn.put_file('/path/to/in/file', '/path/to/dest/file')
+ conn._bare_run.assert_called_with('some command to run', None, checkrc=False)
+
+ conn.put_file(u'/path/to/in/file/with/unicode-fö〩', u'/path/to/dest/file/with/unicode-fö〩')
+ conn._bare_run.assert_called_with('some command to run', None, checkrc=False)
+
+ # test with SCPP_IF_SSH disabled
+ conn.set_option('scp_if_ssh', False)
+ expected_in_data = b' '.join((b'put', to_bytes(shlex_quote('/path/to/in/file')), to_bytes(shlex_quote('/path/to/dest/file')))) + b'\n'
+ conn.put_file('/path/to/in/file', '/path/to/dest/file')
+ conn._bare_run.assert_called_with('some command to run', expected_in_data, checkrc=False)
+
+ expected_in_data = b' '.join((b'put',
+ to_bytes(shlex_quote('/path/to/in/file/with/unicode-fö〩')),
+ to_bytes(shlex_quote('/path/to/dest/file/with/unicode-fö〩')))) + b'\n'
+ conn.put_file(u'/path/to/in/file/with/unicode-fö〩', u'/path/to/dest/file/with/unicode-fö〩')
+ conn._bare_run.assert_called_with('some command to run', expected_in_data, checkrc=False)
+
+ # test that a non-zero rc raises an error
+ conn._bare_run.return_value = (1, 'stdout', 'some errors')
+ self.assertRaises(AnsibleError, conn.put_file, '/path/to/bad/file', '/remote/path/to/file')
+
+ # test that a not-found path raises an error
+ mock_ospe.return_value = False
+ conn._bare_run.return_value = (0, 'stdout', '')
+ self.assertRaises(AnsibleFileNotFound, conn.put_file, '/path/to/bad/file', '/remote/path/to/file')
+
+ @patch('time.sleep')
+ def test_plugins_connection_ssh_fetch_file(self, mock_sleep):
+ pc = PlayContext()
+ new_stdin = StringIO()
+ conn = connection_loader.get('ssh', pc, new_stdin)
+ conn._build_command = MagicMock()
+ conn._bare_run = MagicMock()
+ conn._load_name = 'ssh'
+
+ conn._build_command.return_value = 'some command to run'
+ conn._bare_run.return_value = (0, '', '')
+ conn.host = "some_host"
+
+ conn.set_option('reconnection_retries', 9)
+ conn.set_option('ssh_transfer_method', None) # unless set to None scp_if_ssh is ignored
+
+ # Test with SCP_IF_SSH set to smart
+ # Test when SFTP works
+ conn.set_option('scp_if_ssh', 'smart')
+ expected_in_data = b' '.join((b'get', to_bytes(shlex_quote('/path/to/in/file')), to_bytes(shlex_quote('/path/to/dest/file')))) + b'\n'
+ conn.set_options({})
+ conn.fetch_file('/path/to/in/file', '/path/to/dest/file')
+ conn._bare_run.assert_called_with('some command to run', expected_in_data, checkrc=False)
+
+ # Test when SFTP doesn't work but SCP does
+ conn._bare_run.side_effect = [(1, 'stdout', 'some errors'), (0, '', '')]
+ conn.fetch_file('/path/to/in/file', '/path/to/dest/file')
+ conn._bare_run.assert_called_with('some command to run', None, checkrc=False)
+
+ # test with SCP_IF_SSH enabled
+ conn._bare_run.side_effect = None
+ conn.set_option('ssh_transfer_method', None) # unless set to None scp_if_ssh is ignored
+ conn.set_option('scp_if_ssh', 'True')
+ conn.fetch_file('/path/to/in/file', '/path/to/dest/file')
+ conn._bare_run.assert_called_with('some command to run', None, checkrc=False)
+
+ conn.fetch_file(u'/path/to/in/file/with/unicode-fö〩', u'/path/to/dest/file/with/unicode-fö〩')
+ conn._bare_run.assert_called_with('some command to run', None, checkrc=False)
+
+ # test with SCP_IF_SSH disabled
+ conn.set_option('scp_if_ssh', False)
+ expected_in_data = b' '.join((b'get', to_bytes(shlex_quote('/path/to/in/file')), to_bytes(shlex_quote('/path/to/dest/file')))) + b'\n'
+ conn.fetch_file('/path/to/in/file', '/path/to/dest/file')
+ conn._bare_run.assert_called_with('some command to run', expected_in_data, checkrc=False)
+
+ expected_in_data = b' '.join((b'get',
+ to_bytes(shlex_quote('/path/to/in/file/with/unicode-fö〩')),
+ to_bytes(shlex_quote('/path/to/dest/file/with/unicode-fö〩')))) + b'\n'
+ conn.fetch_file(u'/path/to/in/file/with/unicode-fö〩', u'/path/to/dest/file/with/unicode-fö〩')
+ conn._bare_run.assert_called_with('some command to run', expected_in_data, checkrc=False)
+
+ # test that a non-zero rc raises an error
+ conn._bare_run.return_value = (1, 'stdout', 'some errors')
+ self.assertRaises(AnsibleError, conn.fetch_file, '/path/to/bad/file', '/remote/path/to/file')
+
+
+class MockSelector(object):
+ def __init__(self):
+ self.files_watched = 0
+ self.register = MagicMock(side_effect=self._register)
+ self.unregister = MagicMock(side_effect=self._unregister)
+ self.close = MagicMock()
+ self.get_map = MagicMock(side_effect=self._get_map)
+ self.select = MagicMock()
+
+ def _register(self, *args, **kwargs):
+ self.files_watched += 1
+
+ def _unregister(self, *args, **kwargs):
+ self.files_watched -= 1
+
+ def _get_map(self, *args, **kwargs):
+ return self.files_watched
+
+
+@pytest.fixture
+def mock_run_env(request, mocker):
+ pc = PlayContext()
+ new_stdin = StringIO()
+
+ conn = connection_loader.get('ssh', pc, new_stdin)
+ conn.set_become_plugin(become_loader.get('sudo'))
+ conn._send_initial_data = MagicMock()
+ conn._examine_output = MagicMock()
+ conn._terminate_process = MagicMock()
+ conn._load_name = 'ssh'
+ conn.sshpass_pipe = [MagicMock(), MagicMock()]
+
+ request.cls.pc = pc
+ request.cls.conn = conn
+
+ mock_popen_res = MagicMock()
+ mock_popen_res.poll = MagicMock()
+ mock_popen_res.wait = MagicMock()
+ mock_popen_res.stdin = MagicMock()
+ mock_popen_res.stdin.fileno.return_value = 1000
+ mock_popen_res.stdout = MagicMock()
+ mock_popen_res.stdout.fileno.return_value = 1001
+ mock_popen_res.stderr = MagicMock()
+ mock_popen_res.stderr.fileno.return_value = 1002
+ mock_popen_res.returncode = 0
+ request.cls.mock_popen_res = mock_popen_res
+
+ mock_popen = mocker.patch('subprocess.Popen', return_value=mock_popen_res)
+ request.cls.mock_popen = mock_popen
+
+ request.cls.mock_selector = MockSelector()
+ mocker.patch('ansible.module_utils.compat.selectors.DefaultSelector', lambda: request.cls.mock_selector)
+
+ request.cls.mock_openpty = mocker.patch('pty.openpty')
+
+ mocker.patch('fcntl.fcntl')
+ mocker.patch('os.write')
+ mocker.patch('os.close')
+
+
+@pytest.mark.usefixtures('mock_run_env')
+class TestSSHConnectionRun(object):
+ # FIXME:
+ # These tests are little more than a smoketest. Need to enhance them
+ # a bit to check that they're calling the relevant functions and making
+ # complete coverage of the code paths
+ def test_no_escalation(self):
+ self.mock_popen_res.stdout.read.side_effect = [b"my_stdout\n", b"second_line"]
+ self.mock_popen_res.stderr.read.side_effect = [b"my_stderr"]
+ self.mock_selector.select.side_effect = [
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
+ []]
+ self.mock_selector.get_map.side_effect = lambda: True
+
+ return_code, b_stdout, b_stderr = self.conn._run("ssh", "this is input data")
+ assert return_code == 0
+ assert b_stdout == b'my_stdout\nsecond_line'
+ assert b_stderr == b'my_stderr'
+ assert self.mock_selector.register.called is True
+ assert self.mock_selector.register.call_count == 2
+ assert self.conn._send_initial_data.called is True
+ assert self.conn._send_initial_data.call_count == 1
+ assert self.conn._send_initial_data.call_args[0][1] == 'this is input data'
+
+ def test_with_password(self):
+ # test with a password set to trigger the sshpass write
+ self.pc.password = '12345'
+ self.mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""]
+ self.mock_popen_res.stderr.read.side_effect = [b""]
+ self.mock_selector.select.side_effect = [
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ []]
+ self.mock_selector.get_map.side_effect = lambda: True
+
+ return_code, b_stdout, b_stderr = self.conn._run(["ssh", "is", "a", "cmd"], "this is more data")
+ assert return_code == 0
+ assert b_stdout == b'some data'
+ assert b_stderr == b''
+ assert self.mock_selector.register.called is True
+ assert self.mock_selector.register.call_count == 2
+ assert self.conn._send_initial_data.called is True
+ assert self.conn._send_initial_data.call_count == 1
+ assert self.conn._send_initial_data.call_args[0][1] == 'this is more data'
+
+ def _password_with_prompt_examine_output(self, sourice, state, b_chunk, sudoable):
+ if state == 'awaiting_prompt':
+ self.conn._flags['become_prompt'] = True
+ elif state == 'awaiting_escalation':
+ self.conn._flags['become_success'] = True
+ return (b'', b'')
+
+ def test_password_with_prompt(self):
+ # test with password prompting enabled
+ self.pc.password = None
+ self.conn.become.prompt = b'Password:'
+ self.conn._examine_output.side_effect = self._password_with_prompt_examine_output
+ self.mock_popen_res.stdout.read.side_effect = [b"Password:", b"Success", b""]
+ self.mock_popen_res.stderr.read.side_effect = [b""]
+ self.mock_selector.select.side_effect = [
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ),
+ (SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ []]
+ self.mock_selector.get_map.side_effect = lambda: True
+
+ return_code, b_stdout, b_stderr = self.conn._run("ssh", "this is input data")
+ assert return_code == 0
+ assert b_stdout == b''
+ assert b_stderr == b''
+ assert self.mock_selector.register.called is True
+ assert self.mock_selector.register.call_count == 2
+ assert self.conn._send_initial_data.called is True
+ assert self.conn._send_initial_data.call_count == 1
+ assert self.conn._send_initial_data.call_args[0][1] == 'this is input data'
+
+ def test_password_with_become(self):
+ # test with some become settings
+ self.pc.prompt = b'Password:'
+ self.conn.become.prompt = b'Password:'
+ self.pc.become = True
+ self.pc.success_key = 'BECOME-SUCCESS-abcdefg'
+ self.conn.become._id = 'abcdefg'
+ self.conn._examine_output.side_effect = self._password_with_prompt_examine_output
+ self.mock_popen_res.stdout.read.side_effect = [b"Password:", b"BECOME-SUCCESS-abcdefg", b"abc"]
+ self.mock_popen_res.stderr.read.side_effect = [b"123"]
+ self.mock_selector.select.side_effect = [
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ []]
+ self.mock_selector.get_map.side_effect = lambda: True
+
+ return_code, b_stdout, b_stderr = self.conn._run("ssh", "this is input data")
+ self.mock_popen_res.stdin.flush.assert_called_once_with()
+ assert return_code == 0
+ assert b_stdout == b'abc'
+ assert b_stderr == b'123'
+ assert self.mock_selector.register.called is True
+ assert self.mock_selector.register.call_count == 2
+ assert self.conn._send_initial_data.called is True
+ assert self.conn._send_initial_data.call_count == 1
+ assert self.conn._send_initial_data.call_args[0][1] == 'this is input data'
+
+ def test_pasword_without_data(self):
+ # simulate no data input but Popen using new pty's fails
+ self.mock_popen.return_value = None
+ self.mock_popen.side_effect = [OSError(), self.mock_popen_res]
+
+ # simulate no data input
+ self.mock_openpty.return_value = (98, 99)
+ self.mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""]
+ self.mock_popen_res.stderr.read.side_effect = [b""]
+ self.mock_selector.select.side_effect = [
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ []]
+ self.mock_selector.get_map.side_effect = lambda: True
+
+ return_code, b_stdout, b_stderr = self.conn._run("ssh", "")
+ assert return_code == 0
+ assert b_stdout == b'some data'
+ assert b_stderr == b''
+ assert self.mock_selector.register.called is True
+ assert self.mock_selector.register.call_count == 2
+ assert self.conn._send_initial_data.called is False
+
+
+@pytest.mark.usefixtures('mock_run_env')
+class TestSSHConnectionRetries(object):
+ def test_incorrect_password(self, monkeypatch):
+ self.conn.set_option('host_key_checking', False)
+ self.conn.set_option('reconnection_retries', 5)
+ monkeypatch.setattr('time.sleep', lambda x: None)
+
+ self.mock_popen_res.stdout.read.side_effect = [b'']
+ self.mock_popen_res.stderr.read.side_effect = [b'Permission denied, please try again.\r\n']
+ type(self.mock_popen_res).returncode = PropertyMock(side_effect=[5] * 4)
+
+ self.mock_selector.select.side_effect = [
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
+ [],
+ ]
+
+ self.mock_selector.get_map.side_effect = lambda: True
+
+ self.conn._build_command = MagicMock()
+ self.conn._build_command.return_value = [b'sshpass', b'-d41', b'ssh', b'-C']
+
+ exception_info = pytest.raises(AnsibleAuthenticationFailure, self.conn.exec_command, 'sshpass', 'some data')
+ assert exception_info.value.message == ('Invalid/incorrect username/password. Skipping remaining 5 retries to prevent account lockout: '
+ 'Permission denied, please try again.')
+ assert self.mock_popen.call_count == 1
+
+ def test_retry_then_success(self, monkeypatch):
+ self.conn.set_option('host_key_checking', False)
+ self.conn.set_option('reconnection_retries', 3)
+
+ monkeypatch.setattr('time.sleep', lambda x: None)
+
+ self.mock_popen_res.stdout.read.side_effect = [b"", b"my_stdout\n", b"second_line"]
+ self.mock_popen_res.stderr.read.side_effect = [b"", b"my_stderr"]
+ type(self.mock_popen_res).returncode = PropertyMock(side_effect=[255] * 3 + [0] * 4)
+
+ self.mock_selector.select.side_effect = [
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
+ [],
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
+ []
+ ]
+ self.mock_selector.get_map.side_effect = lambda: True
+
+ self.conn._build_command = MagicMock()
+ self.conn._build_command.return_value = 'ssh'
+
+ return_code, b_stdout, b_stderr = self.conn.exec_command('ssh', 'some data')
+ assert return_code == 0
+ assert b_stdout == b'my_stdout\nsecond_line'
+ assert b_stderr == b'my_stderr'
+
+ def test_multiple_failures(self, monkeypatch):
+ self.conn.set_option('host_key_checking', False)
+ self.conn.set_option('reconnection_retries', 9)
+
+ monkeypatch.setattr('time.sleep', lambda x: None)
+
+ self.mock_popen_res.stdout.read.side_effect = [b""] * 10
+ self.mock_popen_res.stderr.read.side_effect = [b""] * 10
+ type(self.mock_popen_res).returncode = PropertyMock(side_effect=[255] * 30)
+
+ self.mock_selector.select.side_effect = [
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
+ [],
+ ] * 10
+ self.mock_selector.get_map.side_effect = lambda: True
+
+ self.conn._build_command = MagicMock()
+ self.conn._build_command.return_value = 'ssh'
+
+ pytest.raises(AnsibleConnectionFailure, self.conn.exec_command, 'ssh', 'some data')
+ assert self.mock_popen.call_count == 10
+
+ def test_abitrary_exceptions(self, monkeypatch):
+ self.conn.set_option('host_key_checking', False)
+ self.conn.set_option('reconnection_retries', 9)
+
+ monkeypatch.setattr('time.sleep', lambda x: None)
+
+ self.conn._build_command = MagicMock()
+ self.conn._build_command.return_value = 'ssh'
+
+ self.mock_popen.side_effect = [Exception('bad')] * 10
+ pytest.raises(Exception, self.conn.exec_command, 'ssh', 'some data')
+ assert self.mock_popen.call_count == 10
+
+ def test_put_file_retries(self, monkeypatch):
+ self.conn.set_option('host_key_checking', False)
+ self.conn.set_option('reconnection_retries', 3)
+
+ monkeypatch.setattr('time.sleep', lambda x: None)
+ monkeypatch.setattr('ansible.plugins.connection.ssh.os.path.exists', lambda x: True)
+
+ self.mock_popen_res.stdout.read.side_effect = [b"", b"my_stdout\n", b"second_line"]
+ self.mock_popen_res.stderr.read.side_effect = [b"", b"my_stderr"]
+ type(self.mock_popen_res).returncode = PropertyMock(side_effect=[255] * 4 + [0] * 4)
+
+ self.mock_selector.select.side_effect = [
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
+ [],
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
+ []
+ ]
+ self.mock_selector.get_map.side_effect = lambda: True
+
+ self.conn._build_command = MagicMock()
+ self.conn._build_command.return_value = 'sftp'
+
+ return_code, b_stdout, b_stderr = self.conn.put_file('/path/to/in/file', '/path/to/dest/file')
+ assert return_code == 0
+ assert b_stdout == b"my_stdout\nsecond_line"
+ assert b_stderr == b"my_stderr"
+ assert self.mock_popen.call_count == 2
+
+ def test_fetch_file_retries(self, monkeypatch):
+ self.conn.set_option('host_key_checking', False)
+ self.conn.set_option('reconnection_retries', 3)
+
+ monkeypatch.setattr('time.sleep', lambda x: None)
+ monkeypatch.setattr('ansible.plugins.connection.ssh.os.path.exists', lambda x: True)
+
+ self.mock_popen_res.stdout.read.side_effect = [b"", b"my_stdout\n", b"second_line"]
+ self.mock_popen_res.stderr.read.side_effect = [b"", b"my_stderr"]
+ type(self.mock_popen_res).returncode = PropertyMock(side_effect=[255] * 4 + [0] * 4)
+
+ self.mock_selector.select.side_effect = [
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
+ [],
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
+ [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
+ []
+ ]
+ self.mock_selector.get_map.side_effect = lambda: True
+
+ self.conn._build_command = MagicMock()
+ self.conn._build_command.return_value = 'sftp'
+
+ return_code, b_stdout, b_stderr = self.conn.fetch_file('/path/to/in/file', '/path/to/dest/file')
+ assert return_code == 0
+ assert b_stdout == b"my_stdout\nsecond_line"
+ assert b_stderr == b"my_stderr"
+ assert self.mock_popen.call_count == 2
diff --git a/test/units/plugins/connection/test_winrm.py b/test/units/plugins/connection/test_winrm.py
new file mode 100644
index 0000000..cb52814
--- /dev/null
+++ b/test/units/plugins/connection/test_winrm.py
@@ -0,0 +1,443 @@
+# -*- coding: utf-8 -*-
+# (c) 2018, Jordan Borean <jborean@redhat.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import os
+
+import pytest
+
+from io import StringIO
+
+from unittest.mock import MagicMock
+from ansible.errors import AnsibleConnectionFailure
+from ansible.module_utils._text import to_bytes
+from ansible.playbook.play_context import PlayContext
+from ansible.plugins.loader import connection_loader
+from ansible.plugins.connection import winrm
+
+pytest.importorskip("winrm")
+
+
+class TestConnectionWinRM(object):
+
+ OPTIONS_DATA = (
+ # default options
+ (
+ {'_extras': {}},
+ {},
+ {
+ '_kerb_managed': False,
+ '_kinit_cmd': 'kinit',
+ '_winrm_connection_timeout': None,
+ '_winrm_host': 'inventory_hostname',
+ '_winrm_kwargs': {'username': None, 'password': None},
+ '_winrm_pass': None,
+ '_winrm_path': '/wsman',
+ '_winrm_port': 5986,
+ '_winrm_scheme': 'https',
+ '_winrm_transport': ['ssl'],
+ '_winrm_user': None
+ },
+ False
+ ),
+ # http through port
+ (
+ {'_extras': {}, 'ansible_port': 5985},
+ {},
+ {
+ '_winrm_kwargs': {'username': None, 'password': None},
+ '_winrm_port': 5985,
+ '_winrm_scheme': 'http',
+ '_winrm_transport': ['plaintext'],
+ },
+ False
+ ),
+ # kerberos user with kerb present
+ (
+ {'_extras': {}, 'ansible_user': 'user@domain.com'},
+ {},
+ {
+ '_kerb_managed': False,
+ '_kinit_cmd': 'kinit',
+ '_winrm_kwargs': {'username': 'user@domain.com',
+ 'password': None},
+ '_winrm_pass': None,
+ '_winrm_transport': ['kerberos', 'ssl'],
+ '_winrm_user': 'user@domain.com'
+ },
+ True
+ ),
+ # kerberos user without kerb present
+ (
+ {'_extras': {}, 'ansible_user': 'user@domain.com'},
+ {},
+ {
+ '_kerb_managed': False,
+ '_kinit_cmd': 'kinit',
+ '_winrm_kwargs': {'username': 'user@domain.com',
+ 'password': None},
+ '_winrm_pass': None,
+ '_winrm_transport': ['ssl'],
+ '_winrm_user': 'user@domain.com'
+ },
+ False
+ ),
+ # kerberos user with managed ticket (implicit)
+ (
+ {'_extras': {}, 'ansible_user': 'user@domain.com'},
+ {'remote_password': 'pass'},
+ {
+ '_kerb_managed': True,
+ '_kinit_cmd': 'kinit',
+ '_winrm_kwargs': {'username': 'user@domain.com',
+ 'password': 'pass'},
+ '_winrm_pass': 'pass',
+ '_winrm_transport': ['kerberos', 'ssl'],
+ '_winrm_user': 'user@domain.com'
+ },
+ True
+ ),
+ # kerb with managed ticket (explicit)
+ (
+ {'_extras': {}, 'ansible_user': 'user@domain.com',
+ 'ansible_winrm_kinit_mode': 'managed'},
+ {'password': 'pass'},
+ {
+ '_kerb_managed': True,
+ },
+ True
+ ),
+ # kerb with unmanaged ticket (explicit))
+ (
+ {'_extras': {}, 'ansible_user': 'user@domain.com',
+ 'ansible_winrm_kinit_mode': 'manual'},
+ {'password': 'pass'},
+ {
+ '_kerb_managed': False,
+ },
+ True
+ ),
+ # transport override (single)
+ (
+ {'_extras': {}, 'ansible_user': 'user@domain.com',
+ 'ansible_winrm_transport': 'ntlm'},
+ {},
+ {
+ '_winrm_kwargs': {'username': 'user@domain.com',
+ 'password': None},
+ '_winrm_pass': None,
+ '_winrm_transport': ['ntlm'],
+ },
+ False
+ ),
+ # transport override (list)
+ (
+ {'_extras': {}, 'ansible_user': 'user@domain.com',
+ 'ansible_winrm_transport': ['ntlm', 'certificate']},
+ {},
+ {
+ '_winrm_kwargs': {'username': 'user@domain.com',
+ 'password': None},
+ '_winrm_pass': None,
+ '_winrm_transport': ['ntlm', 'certificate'],
+ },
+ False
+ ),
+ # winrm extras
+ (
+ {'_extras': {'ansible_winrm_server_cert_validation': 'ignore',
+ 'ansible_winrm_service': 'WSMAN'}},
+ {},
+ {
+ '_winrm_kwargs': {'username': None, 'password': None,
+ 'server_cert_validation': 'ignore',
+ 'service': 'WSMAN'},
+ },
+ False
+ ),
+ # direct override
+ (
+ {'_extras': {}, 'ansible_winrm_connection_timeout': 5},
+ {'connection_timeout': 10},
+ {
+ '_winrm_connection_timeout': 10,
+ },
+ False
+ ),
+ # password as ansible_password
+ (
+ {'_extras': {}, 'ansible_password': 'pass'},
+ {},
+ {
+ '_winrm_pass': 'pass',
+ '_winrm_kwargs': {'username': None, 'password': 'pass'}
+ },
+ False
+ ),
+ # password as ansible_winrm_pass
+ (
+ {'_extras': {}, 'ansible_winrm_pass': 'pass'},
+ {},
+ {
+ '_winrm_pass': 'pass',
+ '_winrm_kwargs': {'username': None, 'password': 'pass'}
+ },
+ False
+ ),
+
+ # password as ansible_winrm_password
+ (
+ {'_extras': {}, 'ansible_winrm_password': 'pass'},
+ {},
+ {
+ '_winrm_pass': 'pass',
+ '_winrm_kwargs': {'username': None, 'password': 'pass'}
+ },
+ False
+ ),
+ )
+
+ # pylint bug: https://github.com/PyCQA/pylint/issues/511
+ # pylint: disable=undefined-variable
+ @pytest.mark.parametrize('options, direct, expected, kerb',
+ ((o, d, e, k) for o, d, e, k in OPTIONS_DATA))
+ def test_set_options(self, options, direct, expected, kerb):
+ winrm.HAVE_KERBEROS = kerb
+
+ pc = PlayContext()
+ new_stdin = StringIO()
+
+ conn = connection_loader.get('winrm', pc, new_stdin)
+ conn.set_options(var_options=options, direct=direct)
+ conn._build_winrm_kwargs()
+
+ for attr, expected in expected.items():
+ actual = getattr(conn, attr)
+ assert actual == expected, \
+ "winrm attr '%s', actual '%s' != expected '%s'"\
+ % (attr, actual, expected)
+
+
+class TestWinRMKerbAuth(object):
+
+ @pytest.mark.parametrize('options, expected', [
+ [{"_extras": {}},
+ (["kinit", "user@domain"],)],
+ [{"_extras": {}, 'ansible_winrm_kinit_cmd': 'kinit2'},
+ (["kinit2", "user@domain"],)],
+ [{"_extras": {'ansible_winrm_kerberos_delegation': True}},
+ (["kinit", "-f", "user@domain"],)],
+ [{"_extras": {}, 'ansible_winrm_kinit_args': '-f -p'},
+ (["kinit", "-f", "-p", "user@domain"],)],
+ [{"_extras": {}, 'ansible_winrm_kerberos_delegation': True, 'ansible_winrm_kinit_args': '-p'},
+ (["kinit", "-p", "user@domain"],)]
+ ])
+ def test_kinit_success_subprocess(self, monkeypatch, options, expected):
+ def mock_communicate(input=None, timeout=None):
+ return b"", b""
+
+ mock_popen = MagicMock()
+ mock_popen.return_value.communicate = mock_communicate
+ mock_popen.return_value.returncode = 0
+ monkeypatch.setattr("subprocess.Popen", mock_popen)
+
+ winrm.HAS_PEXPECT = False
+ pc = PlayContext()
+ new_stdin = StringIO()
+ conn = connection_loader.get('winrm', pc, new_stdin)
+ conn.set_options(var_options=options)
+ conn._build_winrm_kwargs()
+
+ conn._kerb_auth("user@domain", "pass")
+ mock_calls = mock_popen.mock_calls
+ assert len(mock_calls) == 1
+ assert mock_calls[0][1] == expected
+ actual_env = mock_calls[0][2]['env']
+ assert sorted(list(actual_env.keys())) == ['KRB5CCNAME', 'PATH']
+ assert actual_env['KRB5CCNAME'].startswith("FILE:/")
+ assert actual_env['PATH'] == os.environ['PATH']
+
+ @pytest.mark.parametrize('options, expected', [
+ [{"_extras": {}},
+ ("kinit", ["user@domain"],)],
+ [{"_extras": {}, 'ansible_winrm_kinit_cmd': 'kinit2'},
+ ("kinit2", ["user@domain"],)],
+ [{"_extras": {'ansible_winrm_kerberos_delegation': True}},
+ ("kinit", ["-f", "user@domain"],)],
+ [{"_extras": {}, 'ansible_winrm_kinit_args': '-f -p'},
+ ("kinit", ["-f", "-p", "user@domain"],)],
+ [{"_extras": {}, 'ansible_winrm_kerberos_delegation': True, 'ansible_winrm_kinit_args': '-p'},
+ ("kinit", ["-p", "user@domain"],)]
+ ])
+ def test_kinit_success_pexpect(self, monkeypatch, options, expected):
+ pytest.importorskip("pexpect")
+ mock_pexpect = MagicMock()
+ mock_pexpect.return_value.exitstatus = 0
+ monkeypatch.setattr("pexpect.spawn", mock_pexpect)
+
+ winrm.HAS_PEXPECT = True
+ pc = PlayContext()
+ new_stdin = StringIO()
+ conn = connection_loader.get('winrm', pc, new_stdin)
+ conn.set_options(var_options=options)
+ conn._build_winrm_kwargs()
+
+ conn._kerb_auth("user@domain", "pass")
+ mock_calls = mock_pexpect.mock_calls
+ assert mock_calls[0][1] == expected
+ actual_env = mock_calls[0][2]['env']
+ assert sorted(list(actual_env.keys())) == ['KRB5CCNAME', 'PATH']
+ assert actual_env['KRB5CCNAME'].startswith("FILE:/")
+ assert actual_env['PATH'] == os.environ['PATH']
+ assert mock_calls[0][2]['echo'] is False
+ assert mock_calls[1][0] == "().expect"
+ assert mock_calls[1][1] == (".*:",)
+ assert mock_calls[2][0] == "().sendline"
+ assert mock_calls[2][1] == ("pass",)
+ assert mock_calls[3][0] == "().read"
+ assert mock_calls[4][0] == "().wait"
+
+ def test_kinit_with_missing_executable_subprocess(self, monkeypatch):
+ expected_err = "[Errno 2] No such file or directory: " \
+ "'/fake/kinit': '/fake/kinit'"
+ mock_popen = MagicMock(side_effect=OSError(expected_err))
+
+ monkeypatch.setattr("subprocess.Popen", mock_popen)
+
+ winrm.HAS_PEXPECT = False
+ pc = PlayContext()
+ new_stdin = StringIO()
+ conn = connection_loader.get('winrm', pc, new_stdin)
+ options = {"_extras": {}, "ansible_winrm_kinit_cmd": "/fake/kinit"}
+ conn.set_options(var_options=options)
+ conn._build_winrm_kwargs()
+
+ with pytest.raises(AnsibleConnectionFailure) as err:
+ conn._kerb_auth("user@domain", "pass")
+ assert str(err.value) == "Kerberos auth failure when calling " \
+ "kinit cmd '/fake/kinit': %s" % expected_err
+
+ def test_kinit_with_missing_executable_pexpect(self, monkeypatch):
+ pexpect = pytest.importorskip("pexpect")
+
+ expected_err = "The command was not found or was not " \
+ "executable: /fake/kinit"
+ mock_pexpect = \
+ MagicMock(side_effect=pexpect.ExceptionPexpect(expected_err))
+
+ monkeypatch.setattr("pexpect.spawn", mock_pexpect)
+
+ winrm.HAS_PEXPECT = True
+ pc = PlayContext()
+ new_stdin = StringIO()
+ conn = connection_loader.get('winrm', pc, new_stdin)
+ options = {"_extras": {}, "ansible_winrm_kinit_cmd": "/fake/kinit"}
+ conn.set_options(var_options=options)
+ conn._build_winrm_kwargs()
+
+ with pytest.raises(AnsibleConnectionFailure) as err:
+ conn._kerb_auth("user@domain", "pass")
+ assert str(err.value) == "Kerberos auth failure when calling " \
+ "kinit cmd '/fake/kinit': %s" % expected_err
+
+ def test_kinit_error_subprocess(self, monkeypatch):
+ expected_err = "kinit: krb5_parse_name: " \
+ "Configuration file does not specify default realm"
+
+ def mock_communicate(input=None, timeout=None):
+ return b"", to_bytes(expected_err)
+
+ mock_popen = MagicMock()
+ mock_popen.return_value.communicate = mock_communicate
+ mock_popen.return_value.returncode = 1
+ monkeypatch.setattr("subprocess.Popen", mock_popen)
+
+ winrm.HAS_PEXPECT = False
+ pc = PlayContext()
+ new_stdin = StringIO()
+ conn = connection_loader.get('winrm', pc, new_stdin)
+ conn.set_options(var_options={"_extras": {}})
+ conn._build_winrm_kwargs()
+
+ with pytest.raises(AnsibleConnectionFailure) as err:
+ conn._kerb_auth("invaliduser", "pass")
+
+ assert str(err.value) == \
+ "Kerberos auth failure for principal invaliduser with " \
+ "subprocess: %s" % (expected_err)
+
+ def test_kinit_error_pexpect(self, monkeypatch):
+ pytest.importorskip("pexpect")
+
+ expected_err = "Configuration file does not specify default realm"
+ mock_pexpect = MagicMock()
+ mock_pexpect.return_value.expect = MagicMock(side_effect=OSError)
+ mock_pexpect.return_value.read.return_value = to_bytes(expected_err)
+ mock_pexpect.return_value.exitstatus = 1
+
+ monkeypatch.setattr("pexpect.spawn", mock_pexpect)
+
+ winrm.HAS_PEXPECT = True
+ pc = PlayContext()
+ new_stdin = StringIO()
+ conn = connection_loader.get('winrm', pc, new_stdin)
+ conn.set_options(var_options={"_extras": {}})
+ conn._build_winrm_kwargs()
+
+ with pytest.raises(AnsibleConnectionFailure) as err:
+ conn._kerb_auth("invaliduser", "pass")
+
+ assert str(err.value) == \
+ "Kerberos auth failure for principal invaliduser with " \
+ "pexpect: %s" % (expected_err)
+
+ def test_kinit_error_pass_in_output_subprocess(self, monkeypatch):
+ def mock_communicate(input=None, timeout=None):
+ return b"", b"Error with kinit\n" + input
+
+ mock_popen = MagicMock()
+ mock_popen.return_value.communicate = mock_communicate
+ mock_popen.return_value.returncode = 1
+ monkeypatch.setattr("subprocess.Popen", mock_popen)
+
+ winrm.HAS_PEXPECT = False
+ pc = PlayContext()
+ new_stdin = StringIO()
+ conn = connection_loader.get('winrm', pc, new_stdin)
+ conn.set_options(var_options={"_extras": {}})
+ conn._build_winrm_kwargs()
+
+ with pytest.raises(AnsibleConnectionFailure) as err:
+ conn._kerb_auth("username", "password")
+ assert str(err.value) == \
+ "Kerberos auth failure for principal username with subprocess: " \
+ "Error with kinit\n<redacted>"
+
+ def test_kinit_error_pass_in_output_pexpect(self, monkeypatch):
+ pytest.importorskip("pexpect")
+
+ mock_pexpect = MagicMock()
+ mock_pexpect.return_value.expect = MagicMock()
+ mock_pexpect.return_value.read.return_value = \
+ b"Error with kinit\npassword\n"
+ mock_pexpect.return_value.exitstatus = 1
+
+ monkeypatch.setattr("pexpect.spawn", mock_pexpect)
+
+ winrm.HAS_PEXPECT = True
+ pc = PlayContext()
+ pc = PlayContext()
+ new_stdin = StringIO()
+ conn = connection_loader.get('winrm', pc, new_stdin)
+ conn.set_options(var_options={"_extras": {}})
+ conn._build_winrm_kwargs()
+
+ with pytest.raises(AnsibleConnectionFailure) as err:
+ conn._kerb_auth("username", "password")
+ assert str(err.value) == \
+ "Kerberos auth failure for principal username with pexpect: " \
+ "Error with kinit\n<redacted>"
diff --git a/test/units/plugins/filter/__init__.py b/test/units/plugins/filter/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/units/plugins/filter/__init__.py
diff --git a/test/units/plugins/filter/test_core.py b/test/units/plugins/filter/test_core.py
new file mode 100644
index 0000000..df4e472
--- /dev/null
+++ b/test/units/plugins/filter/test_core.py
@@ -0,0 +1,43 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2019 Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+from jinja2.runtime import Undefined
+from jinja2.exceptions import UndefinedError
+__metaclass__ = type
+
+import pytest
+
+from ansible.module_utils._text import to_native
+from ansible.plugins.filter.core import to_uuid
+from ansible.errors import AnsibleFilterError
+
+
+UUID_DEFAULT_NAMESPACE_TEST_CASES = (
+ ('example.com', 'ae780c3a-a3ab-53c2-bfb4-098da300b3fe'),
+ ('test.example', '8e437a35-c7c5-50ea-867c-5c254848dbc2'),
+ ('café.example', '8a99d6b1-fb8f-5f78-af86-879768589f56'),
+)
+
+UUID_TEST_CASES = (
+ ('361E6D51-FAEC-444A-9079-341386DA8E2E', 'example.com', 'ae780c3a-a3ab-53c2-bfb4-098da300b3fe'),
+ ('361E6D51-FAEC-444A-9079-341386DA8E2E', 'test.example', '8e437a35-c7c5-50ea-867c-5c254848dbc2'),
+ ('11111111-2222-3333-4444-555555555555', 'example.com', 'e776faa5-5299-55dc-9057-7a00e6be2364'),
+)
+
+
+@pytest.mark.parametrize('value, expected', UUID_DEFAULT_NAMESPACE_TEST_CASES)
+def test_to_uuid_default_namespace(value, expected):
+ assert expected == to_uuid(value)
+
+
+@pytest.mark.parametrize('namespace, value, expected', UUID_TEST_CASES)
+def test_to_uuid(namespace, value, expected):
+ assert expected == to_uuid(value, namespace=namespace)
+
+
+def test_to_uuid_invalid_namespace():
+ with pytest.raises(AnsibleFilterError) as e:
+ to_uuid('example.com', namespace='11111111-2222-3333-4444-555555555')
+ assert 'Invalid value' in to_native(e.value)
diff --git a/test/units/plugins/filter/test_mathstuff.py b/test/units/plugins/filter/test_mathstuff.py
new file mode 100644
index 0000000..f793871
--- /dev/null
+++ b/test/units/plugins/filter/test_mathstuff.py
@@ -0,0 +1,162 @@
+# Copyright: (c) 2017, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+import pytest
+
+from jinja2 import Environment
+
+import ansible.plugins.filter.mathstuff as ms
+from ansible.errors import AnsibleFilterError, AnsibleFilterTypeError
+
+
+UNIQUE_DATA = (([1, 3, 4, 2], [1, 3, 4, 2]),
+ ([1, 3, 2, 4, 2, 3], [1, 3, 2, 4]),
+ (['a', 'b', 'c', 'd'], ['a', 'b', 'c', 'd']),
+ (['a', 'a', 'd', 'b', 'a', 'd', 'c', 'b'], ['a', 'd', 'b', 'c']),
+ )
+
+TWO_SETS_DATA = (([1, 2], [3, 4], ([], sorted([1, 2]), sorted([1, 2, 3, 4]), sorted([1, 2, 3, 4]))),
+ ([1, 2, 3], [5, 3, 4], ([3], sorted([1, 2]), sorted([1, 2, 5, 4]), sorted([1, 2, 3, 4, 5]))),
+ (['a', 'b', 'c'], ['d', 'c', 'e'], (['c'], sorted(['a', 'b']), sorted(['a', 'b', 'd', 'e']), sorted(['a', 'b', 'c', 'e', 'd']))),
+ )
+
+env = Environment()
+
+
+@pytest.mark.parametrize('data, expected', UNIQUE_DATA)
+class TestUnique:
+ def test_unhashable(self, data, expected):
+ assert ms.unique(env, list(data)) == expected
+
+ def test_hashable(self, data, expected):
+ assert ms.unique(env, tuple(data)) == expected
+
+
+@pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA)
+class TestIntersect:
+ def test_unhashable(self, dataset1, dataset2, expected):
+ assert sorted(ms.intersect(env, list(dataset1), list(dataset2))) == expected[0]
+
+ def test_hashable(self, dataset1, dataset2, expected):
+ assert sorted(ms.intersect(env, tuple(dataset1), tuple(dataset2))) == expected[0]
+
+
+@pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA)
+class TestDifference:
+ def test_unhashable(self, dataset1, dataset2, expected):
+ assert sorted(ms.difference(env, list(dataset1), list(dataset2))) == expected[1]
+
+ def test_hashable(self, dataset1, dataset2, expected):
+ assert sorted(ms.difference(env, tuple(dataset1), tuple(dataset2))) == expected[1]
+
+
+@pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA)
+class TestSymmetricDifference:
+ def test_unhashable(self, dataset1, dataset2, expected):
+ assert sorted(ms.symmetric_difference(env, list(dataset1), list(dataset2))) == expected[2]
+
+ def test_hashable(self, dataset1, dataset2, expected):
+ assert sorted(ms.symmetric_difference(env, tuple(dataset1), tuple(dataset2))) == expected[2]
+
+
+class TestLogarithm:
+ def test_log_non_number(self):
+ # Message changed in python3.6
+ with pytest.raises(AnsibleFilterTypeError, match='log\\(\\) can only be used on numbers: (a float is required|must be real number, not str)'):
+ ms.logarithm('a')
+ with pytest.raises(AnsibleFilterTypeError, match='log\\(\\) can only be used on numbers: (a float is required|must be real number, not str)'):
+ ms.logarithm(10, base='a')
+
+ def test_log_ten(self):
+ assert ms.logarithm(10, 10) == 1.0
+ assert ms.logarithm(69, 10) * 1000 // 1 == 1838
+
+ def test_log_natural(self):
+ assert ms.logarithm(69) * 1000 // 1 == 4234
+
+ def test_log_two(self):
+ assert ms.logarithm(69, 2) * 1000 // 1 == 6108
+
+
+class TestPower:
+ def test_power_non_number(self):
+ # Message changed in python3.6
+ with pytest.raises(AnsibleFilterTypeError, match='pow\\(\\) can only be used on numbers: (a float is required|must be real number, not str)'):
+ ms.power('a', 10)
+
+ with pytest.raises(AnsibleFilterTypeError, match='pow\\(\\) can only be used on numbers: (a float is required|must be real number, not str)'):
+ ms.power(10, 'a')
+
+ def test_power_squared(self):
+ assert ms.power(10, 2) == 100
+
+ def test_power_cubed(self):
+ assert ms.power(10, 3) == 1000
+
+
+class TestInversePower:
+ def test_root_non_number(self):
+ # Messages differed in python-2.6, python-2.7-3.5, and python-3.6+
+ with pytest.raises(AnsibleFilterTypeError, match="root\\(\\) can only be used on numbers:"
+ " (invalid literal for float\\(\\): a"
+ "|could not convert string to float: a"
+ "|could not convert string to float: 'a')"):
+ ms.inversepower(10, 'a')
+
+ with pytest.raises(AnsibleFilterTypeError, match="root\\(\\) can only be used on numbers: (a float is required|must be real number, not str)"):
+ ms.inversepower('a', 10)
+
+ def test_square_root(self):
+ assert ms.inversepower(100) == 10
+ assert ms.inversepower(100, 2) == 10
+
+ def test_cube_root(self):
+ assert ms.inversepower(27, 3) == 3
+
+
+class TestRekeyOnMember():
+ # (Input data structure, member to rekey on, expected return)
+ VALID_ENTRIES = (
+ ([{"proto": "eigrp", "state": "enabled"}, {"proto": "ospf", "state": "enabled"}],
+ 'proto',
+ {'eigrp': {'state': 'enabled', 'proto': 'eigrp'}, 'ospf': {'state': 'enabled', 'proto': 'ospf'}}),
+ ({'eigrp': {"proto": "eigrp", "state": "enabled"}, 'ospf': {"proto": "ospf", "state": "enabled"}},
+ 'proto',
+ {'eigrp': {'state': 'enabled', 'proto': 'eigrp'}, 'ospf': {'state': 'enabled', 'proto': 'ospf'}}),
+ )
+
+ # (Input data structure, member to rekey on, expected error message)
+ INVALID_ENTRIES = (
+ # Fail when key is not found
+ (AnsibleFilterError, [{"proto": "eigrp", "state": "enabled"}], 'invalid_key', "Key invalid_key was not found"),
+ (AnsibleFilterError, {"eigrp": {"proto": "eigrp", "state": "enabled"}}, 'invalid_key', "Key invalid_key was not found"),
+ # Fail when key is duplicated
+ (AnsibleFilterError, [{"proto": "eigrp"}, {"proto": "ospf"}, {"proto": "ospf"}],
+ 'proto', 'Key ospf is not unique, cannot correctly turn into dict'),
+ # Fail when value is not a dict
+ (AnsibleFilterTypeError, ["string"], 'proto', "List item is not a valid dict"),
+ (AnsibleFilterTypeError, [123], 'proto', "List item is not a valid dict"),
+ (AnsibleFilterTypeError, [[{'proto': 1}]], 'proto', "List item is not a valid dict"),
+ # Fail when we do not send a dict or list
+ (AnsibleFilterTypeError, "string", 'proto', "Type is not a valid list, set, or dict"),
+ (AnsibleFilterTypeError, 123, 'proto', "Type is not a valid list, set, or dict"),
+ )
+
+ @pytest.mark.parametrize("list_original, key, expected", VALID_ENTRIES)
+ def test_rekey_on_member_success(self, list_original, key, expected):
+ assert ms.rekey_on_member(list_original, key) == expected
+
+ @pytest.mark.parametrize("expected_exception_type, list_original, key, expected", INVALID_ENTRIES)
+ def test_fail_rekey_on_member(self, expected_exception_type, list_original, key, expected):
+ with pytest.raises(expected_exception_type) as err:
+ ms.rekey_on_member(list_original, key)
+
+ assert err.value.message == expected
+
+ def test_duplicate_strategy_overwrite(self):
+ list_original = ({'proto': 'eigrp', 'id': 1}, {'proto': 'ospf', 'id': 2}, {'proto': 'eigrp', 'id': 3})
+ expected = {'eigrp': {'proto': 'eigrp', 'id': 3}, 'ospf': {'proto': 'ospf', 'id': 2}}
+ assert ms.rekey_on_member(list_original, 'proto', duplicates='overwrite') == expected
diff --git a/test/units/plugins/inventory/__init__.py b/test/units/plugins/inventory/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/units/plugins/inventory/__init__.py
diff --git a/test/units/plugins/inventory/test_constructed.py b/test/units/plugins/inventory/test_constructed.py
new file mode 100644
index 0000000..581e025
--- /dev/null
+++ b/test/units/plugins/inventory/test_constructed.py
@@ -0,0 +1,337 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Alan Rominger <arominge@redhat.net>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import pytest
+
+from ansible.errors import AnsibleParserError
+from ansible.plugins.inventory.constructed import InventoryModule
+from ansible.inventory.data import InventoryData
+from ansible.template import Templar
+
+
+@pytest.fixture()
+def inventory_module():
+ r = InventoryModule()
+ r.inventory = InventoryData()
+ r.templar = Templar(None)
+ r._options = {'leading_separator': True}
+ return r
+
+
+def test_group_by_value_only(inventory_module):
+ inventory_module.inventory.add_host('foohost')
+ inventory_module.inventory.set_variable('foohost', 'bar', 'my_group_name')
+ host = inventory_module.inventory.get_host('foohost')
+ keyed_groups = [
+ {
+ 'prefix': '',
+ 'separator': '',
+ 'key': 'bar'
+ }
+ ]
+ inventory_module._add_host_to_keyed_groups(
+ keyed_groups, host.vars, host.name, strict=False
+ )
+ assert 'my_group_name' in inventory_module.inventory.groups
+ group = inventory_module.inventory.groups['my_group_name']
+ assert group.hosts == [host]
+
+
+def test_keyed_group_separator(inventory_module):
+ inventory_module.inventory.add_host('farm')
+ inventory_module.inventory.set_variable('farm', 'farmer', 'mcdonald')
+ inventory_module.inventory.set_variable('farm', 'barn', {'cow': 'betsy'})
+ host = inventory_module.inventory.get_host('farm')
+ keyed_groups = [
+ {
+ 'prefix': 'farmer',
+ 'separator': '_old_',
+ 'key': 'farmer'
+ },
+ {
+ 'separator': 'mmmmmmmmmm',
+ 'key': 'barn'
+ }
+ ]
+ inventory_module._add_host_to_keyed_groups(
+ keyed_groups, host.vars, host.name, strict=False
+ )
+ for group_name in ('farmer_old_mcdonald', 'mmmmmmmmmmcowmmmmmmmmmmbetsy'):
+ assert group_name in inventory_module.inventory.groups
+ group = inventory_module.inventory.groups[group_name]
+ assert group.hosts == [host]
+
+
+def test_keyed_group_empty_construction(inventory_module):
+ inventory_module.inventory.add_host('farm')
+ inventory_module.inventory.set_variable('farm', 'barn', {})
+ host = inventory_module.inventory.get_host('farm')
+ keyed_groups = [
+ {
+ 'separator': 'mmmmmmmmmm',
+ 'key': 'barn'
+ }
+ ]
+ inventory_module._add_host_to_keyed_groups(
+ keyed_groups, host.vars, host.name, strict=True
+ )
+ assert host.groups == []
+
+
+def test_keyed_group_host_confusion(inventory_module):
+ inventory_module.inventory.add_host('cow')
+ inventory_module.inventory.add_group('cow')
+ host = inventory_module.inventory.get_host('cow')
+ host.vars['species'] = 'cow'
+ keyed_groups = [
+ {
+ 'separator': '',
+ 'prefix': '',
+ 'key': 'species'
+ }
+ ]
+ inventory_module._add_host_to_keyed_groups(
+ keyed_groups, host.vars, host.name, strict=True
+ )
+ group = inventory_module.inventory.groups['cow']
+ # group cow has host of cow
+ assert group.hosts == [host]
+
+
+def test_keyed_parent_groups(inventory_module):
+ inventory_module.inventory.add_host('web1')
+ inventory_module.inventory.add_host('web2')
+ inventory_module.inventory.set_variable('web1', 'region', 'japan')
+ inventory_module.inventory.set_variable('web2', 'region', 'japan')
+ host1 = inventory_module.inventory.get_host('web1')
+ host2 = inventory_module.inventory.get_host('web2')
+ keyed_groups = [
+ {
+ 'prefix': 'region',
+ 'key': 'region',
+ 'parent_group': 'region_list'
+ }
+ ]
+ for host in [host1, host2]:
+ inventory_module._add_host_to_keyed_groups(
+ keyed_groups, host.vars, host.name, strict=False
+ )
+ assert 'region_japan' in inventory_module.inventory.groups
+ assert 'region_list' in inventory_module.inventory.groups
+ region_group = inventory_module.inventory.groups['region_japan']
+ all_regions = inventory_module.inventory.groups['region_list']
+ assert all_regions.child_groups == [region_group]
+ assert region_group.hosts == [host1, host2]
+
+
+def test_parent_group_templating(inventory_module):
+ inventory_module.inventory.add_host('cow')
+ inventory_module.inventory.set_variable('cow', 'sound', 'mmmmmmmmmm')
+ inventory_module.inventory.set_variable('cow', 'nickname', 'betsy')
+ host = inventory_module.inventory.get_host('cow')
+ keyed_groups = [
+ {
+ 'key': 'sound',
+ 'prefix': 'sound',
+ 'parent_group': '{{ nickname }}'
+ },
+ {
+ 'key': 'nickname',
+ 'prefix': '',
+ 'separator': '',
+ 'parent_group': 'nickname' # statically-named parent group, conflicting with hostvar
+ },
+ {
+ 'key': 'nickname',
+ 'separator': '',
+ 'parent_group': '{{ location | default("field") }}'
+ }
+ ]
+ inventory_module._add_host_to_keyed_groups(
+ keyed_groups, host.vars, host.name, strict=True
+ )
+ # first keyed group, "betsy" is a parent group name dynamically generated
+ betsys_group = inventory_module.inventory.groups['betsy']
+ assert [child.name for child in betsys_group.child_groups] == ['sound_mmmmmmmmmm']
+ # second keyed group, "nickname" is a statically-named root group
+ nicknames_group = inventory_module.inventory.groups['nickname']
+ assert [child.name for child in nicknames_group.child_groups] == ['betsy']
+ # second keyed group actually generated the parent group of the first keyed group
+ # assert that these are, in fact, the same object
+ assert nicknames_group.child_groups[0] == betsys_group
+ # second keyed group has two parents
+ locations_group = inventory_module.inventory.groups['field']
+ assert [child.name for child in locations_group.child_groups] == ['betsy']
+
+
+def test_parent_group_templating_error(inventory_module):
+ inventory_module.inventory.add_host('cow')
+ inventory_module.inventory.set_variable('cow', 'nickname', 'betsy')
+ host = inventory_module.inventory.get_host('cow')
+ keyed_groups = [
+ {
+ 'key': 'nickname',
+ 'separator': '',
+ 'parent_group': '{{ location.barn-yard }}'
+ }
+ ]
+ with pytest.raises(AnsibleParserError) as err_message:
+ inventory_module._add_host_to_keyed_groups(
+ keyed_groups, host.vars, host.name, strict=True
+ )
+ assert 'Could not generate parent group' in err_message
+ # invalid parent group did not raise an exception with strict=False
+ inventory_module._add_host_to_keyed_groups(
+ keyed_groups, host.vars, host.name, strict=False
+ )
+ # assert group was never added with invalid parent
+ assert 'betsy' not in inventory_module.inventory.groups
+
+
+def test_keyed_group_exclusive_argument(inventory_module):
+ inventory_module.inventory.add_host('cow')
+ inventory_module.inventory.set_variable('cow', 'nickname', 'betsy')
+ host = inventory_module.inventory.get_host('cow')
+ keyed_groups = [
+ {
+ 'key': 'tag',
+ 'separator': '_',
+ 'default_value': 'default_value_name',
+ 'trailing_separator': True
+ }
+ ]
+ with pytest.raises(AnsibleParserError) as err_message:
+ inventory_module._add_host_to_keyed_groups(
+ keyed_groups, host.vars, host.name, strict=True
+ )
+ assert 'parameters are mutually exclusive' in err_message
+
+
+def test_keyed_group_empty_value(inventory_module):
+ inventory_module.inventory.add_host('server0')
+ inventory_module.inventory.set_variable('server0', 'tags', {'environment': 'prod', 'status': ''})
+ host = inventory_module.inventory.get_host('server0')
+ keyed_groups = [
+ {
+ 'prefix': 'tag',
+ 'separator': '_',
+ 'key': 'tags'
+ }
+ ]
+ inventory_module._add_host_to_keyed_groups(
+ keyed_groups, host.vars, host.name, strict=False
+ )
+ for group_name in ('tag_environment_prod', 'tag_status_'):
+ assert group_name in inventory_module.inventory.groups
+
+
+def test_keyed_group_dict_with_default_value(inventory_module):
+ inventory_module.inventory.add_host('server0')
+ inventory_module.inventory.set_variable('server0', 'tags', {'environment': 'prod', 'status': ''})
+ host = inventory_module.inventory.get_host('server0')
+ keyed_groups = [
+ {
+ 'prefix': 'tag',
+ 'separator': '_',
+ 'key': 'tags',
+ 'default_value': 'running'
+ }
+ ]
+ inventory_module._add_host_to_keyed_groups(
+ keyed_groups, host.vars, host.name, strict=False
+ )
+ for group_name in ('tag_environment_prod', 'tag_status_running'):
+ assert group_name in inventory_module.inventory.groups
+
+
+def test_keyed_group_str_no_default_value(inventory_module):
+ inventory_module.inventory.add_host('server0')
+ inventory_module.inventory.set_variable('server0', 'tags', '')
+ host = inventory_module.inventory.get_host('server0')
+ keyed_groups = [
+ {
+ 'prefix': 'tag',
+ 'separator': '_',
+ 'key': 'tags'
+ }
+ ]
+ inventory_module._add_host_to_keyed_groups(
+ keyed_groups, host.vars, host.name, strict=False
+ )
+ # when the value is an empty string. this group is not generated
+ assert "tag_" not in inventory_module.inventory.groups
+
+
+def test_keyed_group_str_with_default_value(inventory_module):
+ inventory_module.inventory.add_host('server0')
+ inventory_module.inventory.set_variable('server0', 'tags', '')
+ host = inventory_module.inventory.get_host('server0')
+ keyed_groups = [
+ {
+ 'prefix': 'tag',
+ 'separator': '_',
+ 'key': 'tags',
+ 'default_value': 'running'
+ }
+ ]
+ inventory_module._add_host_to_keyed_groups(
+ keyed_groups, host.vars, host.name, strict=False
+ )
+ assert "tag_running" in inventory_module.inventory.groups
+
+
+def test_keyed_group_list_with_default_value(inventory_module):
+ inventory_module.inventory.add_host('server0')
+ inventory_module.inventory.set_variable('server0', 'tags', ['test', ''])
+ host = inventory_module.inventory.get_host('server0')
+ keyed_groups = [
+ {
+ 'prefix': 'tag',
+ 'separator': '_',
+ 'key': 'tags',
+ 'default_value': 'prod'
+ }
+ ]
+ inventory_module._add_host_to_keyed_groups(
+ keyed_groups, host.vars, host.name, strict=False
+ )
+ for group_name in ('tag_test', 'tag_prod'):
+ assert group_name in inventory_module.inventory.groups
+
+
+def test_keyed_group_with_trailing_separator(inventory_module):
+ inventory_module.inventory.add_host('server0')
+ inventory_module.inventory.set_variable('server0', 'tags', {'environment': 'prod', 'status': ''})
+ host = inventory_module.inventory.get_host('server0')
+ keyed_groups = [
+ {
+ 'prefix': 'tag',
+ 'separator': '_',
+ 'key': 'tags',
+ 'trailing_separator': False
+ }
+ ]
+ inventory_module._add_host_to_keyed_groups(
+ keyed_groups, host.vars, host.name, strict=False
+ )
+ for group_name in ('tag_environment_prod', 'tag_status'):
+ assert group_name in inventory_module.inventory.groups
diff --git a/test/units/plugins/inventory/test_inventory.py b/test/units/plugins/inventory/test_inventory.py
new file mode 100644
index 0000000..df24607
--- /dev/null
+++ b/test/units/plugins/inventory/test_inventory.py
@@ -0,0 +1,208 @@
+# Copyright 2015 Abhijit Menon-Sen <ams@2ndQuadrant.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import string
+import textwrap
+
+from unittest import mock
+
+from ansible import constants as C
+from units.compat import unittest
+from ansible.module_utils.six import string_types
+from ansible.module_utils._text import to_text
+from units.mock.path import mock_unfrackpath_noop
+
+from ansible.inventory.manager import InventoryManager, split_host_pattern
+
+from units.mock.loader import DictDataLoader
+
+
+class TestInventory(unittest.TestCase):
+
+ patterns = {
+ 'a': ['a'],
+ 'a, b': ['a', 'b'],
+ 'a , b': ['a', 'b'],
+ ' a,b ,c[1:2] ': ['a', 'b', 'c[1:2]'],
+ '9a01:7f8:191:7701::9': ['9a01:7f8:191:7701::9'],
+ '9a01:7f8:191:7701::9,9a01:7f8:191:7701::9': ['9a01:7f8:191:7701::9', '9a01:7f8:191:7701::9'],
+ '9a01:7f8:191:7701::9,9a01:7f8:191:7701::9,foo': ['9a01:7f8:191:7701::9', '9a01:7f8:191:7701::9', 'foo'],
+ 'foo[1:2]': ['foo[1:2]'],
+ 'a::b': ['a::b'],
+ 'a:b': ['a', 'b'],
+ ' a : b ': ['a', 'b'],
+ 'foo:bar:baz[1:2]': ['foo', 'bar', 'baz[1:2]'],
+ 'a,,b': ['a', 'b'],
+ 'a, ,b,,c, ,': ['a', 'b', 'c'],
+ ',': [],
+ '': [],
+ }
+
+ pattern_lists = [
+ [['a'], ['a']],
+ [['a', 'b'], ['a', 'b']],
+ [['a, b'], ['a', 'b']],
+ [['9a01:7f8:191:7701::9', '9a01:7f8:191:7701::9,foo'],
+ ['9a01:7f8:191:7701::9', '9a01:7f8:191:7701::9', 'foo']]
+ ]
+
+ # pattern_string: [ ('base_pattern', (a,b)), ['x','y','z'] ]
+ # a,b are the bounds of the subscript; x..z are the results of the subscript
+ # when applied to string.ascii_letters.
+
+ subscripts = {
+ 'a': [('a', None), list(string.ascii_letters)],
+ 'a[0]': [('a', (0, None)), ['a']],
+ 'a[1]': [('a', (1, None)), ['b']],
+ 'a[2:3]': [('a', (2, 3)), ['c', 'd']],
+ 'a[-1]': [('a', (-1, None)), ['Z']],
+ 'a[-2]': [('a', (-2, None)), ['Y']],
+ 'a[48:]': [('a', (48, -1)), ['W', 'X', 'Y', 'Z']],
+ 'a[49:]': [('a', (49, -1)), ['X', 'Y', 'Z']],
+ 'a[1:]': [('a', (1, -1)), list(string.ascii_letters[1:])],
+ }
+
+ ranges_to_expand = {
+ 'a[1:2]': ['a1', 'a2'],
+ 'a[1:10:2]': ['a1', 'a3', 'a5', 'a7', 'a9'],
+ 'a[a:b]': ['aa', 'ab'],
+ 'a[a:i:3]': ['aa', 'ad', 'ag'],
+ 'a[a:b][c:d]': ['aac', 'aad', 'abc', 'abd'],
+ 'a[0:1][2:3]': ['a02', 'a03', 'a12', 'a13'],
+ 'a[a:b][2:3]': ['aa2', 'aa3', 'ab2', 'ab3'],
+ }
+
+ def setUp(self):
+ fake_loader = DictDataLoader({})
+
+ self.i = InventoryManager(loader=fake_loader, sources=[None])
+
+ def test_split_patterns(self):
+
+ for p in self.patterns:
+ r = self.patterns[p]
+ self.assertEqual(r, split_host_pattern(p))
+
+ for p, r in self.pattern_lists:
+ self.assertEqual(r, split_host_pattern(p))
+
+ def test_ranges(self):
+
+ for s in self.subscripts:
+ r = self.subscripts[s]
+ self.assertEqual(r[0], self.i._split_subscript(s))
+ self.assertEqual(
+ r[1],
+ self.i._apply_subscript(
+ list(string.ascii_letters),
+ r[0][1]
+ )
+ )
+
+
+class TestInventoryPlugins(unittest.TestCase):
+
+ def test_empty_inventory(self):
+ inventory = self._get_inventory('')
+
+ self.assertIn('all', inventory.groups)
+ self.assertIn('ungrouped', inventory.groups)
+ self.assertFalse(inventory.groups['all'].get_hosts())
+ self.assertFalse(inventory.groups['ungrouped'].get_hosts())
+
+ def test_ini(self):
+ self._test_default_groups("""
+ host1
+ host2
+ host3
+ [servers]
+ host3
+ host4
+ host5
+ """)
+
+ def test_ini_explicit_ungrouped(self):
+ self._test_default_groups("""
+ [ungrouped]
+ host1
+ host2
+ host3
+ [servers]
+ host3
+ host4
+ host5
+ """)
+
+ def test_ini_variables_stringify(self):
+ values = ['string', 'no', 'No', 'false', 'FALSE', [], False, 0]
+
+ inventory_content = "host1 "
+ inventory_content += ' '.join(['var%s=%s' % (i, to_text(x)) for i, x in enumerate(values)])
+ inventory = self._get_inventory(inventory_content)
+
+ variables = inventory.get_host('host1').vars
+ for i in range(len(values)):
+ if isinstance(values[i], string_types):
+ self.assertIsInstance(variables['var%s' % i], string_types)
+ else:
+ self.assertIsInstance(variables['var%s' % i], type(values[i]))
+
+ @mock.patch('ansible.inventory.manager.unfrackpath', mock_unfrackpath_noop)
+ @mock.patch('os.path.exists', lambda x: True)
+ @mock.patch('os.access', lambda x, y: True)
+ def test_yaml_inventory(self, filename="test.yaml"):
+ inventory_content = {filename: textwrap.dedent("""\
+ ---
+ all:
+ hosts:
+ test1:
+ test2:
+ """)}
+ C.INVENTORY_ENABLED = ['yaml']
+ fake_loader = DictDataLoader(inventory_content)
+ im = InventoryManager(loader=fake_loader, sources=filename)
+ self.assertTrue(im._inventory.hosts)
+ self.assertIn('test1', im._inventory.hosts)
+ self.assertIn('test2', im._inventory.hosts)
+ self.assertIn(im._inventory.get_host('test1'), im._inventory.groups['all'].hosts)
+ self.assertIn(im._inventory.get_host('test2'), im._inventory.groups['all'].hosts)
+ self.assertEqual(len(im._inventory.groups['all'].hosts), 2)
+ self.assertIn(im._inventory.get_host('test1'), im._inventory.groups['ungrouped'].hosts)
+ self.assertIn(im._inventory.get_host('test2'), im._inventory.groups['ungrouped'].hosts)
+ self.assertEqual(len(im._inventory.groups['ungrouped'].hosts), 2)
+
+ def _get_inventory(self, inventory_content):
+
+ fake_loader = DictDataLoader({__file__: inventory_content})
+
+ return InventoryManager(loader=fake_loader, sources=[__file__])
+
+ def _test_default_groups(self, inventory_content):
+ inventory = self._get_inventory(inventory_content)
+
+ self.assertIn('all', inventory.groups)
+ self.assertIn('ungrouped', inventory.groups)
+ all_hosts = set(host.name for host in inventory.groups['all'].get_hosts())
+ self.assertEqual(set(['host1', 'host2', 'host3', 'host4', 'host5']), all_hosts)
+ ungrouped_hosts = set(host.name for host in inventory.groups['ungrouped'].get_hosts())
+ self.assertEqual(set(['host1', 'host2']), ungrouped_hosts)
+ servers_hosts = set(host.name for host in inventory.groups['servers'].get_hosts())
+ self.assertEqual(set(['host3', 'host4', 'host5']), servers_hosts)
diff --git a/test/units/plugins/inventory/test_script.py b/test/units/plugins/inventory/test_script.py
new file mode 100644
index 0000000..9f75199
--- /dev/null
+++ b/test/units/plugins/inventory/test_script.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2017 Chris Meyers <cmeyers@ansible.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import pytest
+from unittest import mock
+
+from ansible import constants as C
+from ansible.errors import AnsibleError
+from ansible.plugins.loader import PluginLoader
+from units.compat import unittest
+from ansible.module_utils._text import to_bytes, to_native
+
+
+class TestInventoryModule(unittest.TestCase):
+
+ def setUp(self):
+
+ class Inventory():
+ cache = dict()
+
+ class PopenResult():
+ returncode = 0
+ stdout = b""
+ stderr = b""
+
+ def communicate(self):
+ return (self.stdout, self.stderr)
+
+ self.popen_result = PopenResult()
+ self.inventory = Inventory()
+ self.loader = mock.MagicMock()
+ self.loader.load = mock.MagicMock()
+
+ inv_loader = PluginLoader('InventoryModule', 'ansible.plugins.inventory', C.DEFAULT_INVENTORY_PLUGIN_PATH, 'inventory_plugins')
+ self.inventory_module = inv_loader.get('script')
+ self.inventory_module.set_options()
+
+ def register_patch(name):
+ patcher = mock.patch(name)
+ self.addCleanup(patcher.stop)
+ return patcher.start()
+
+ self.popen = register_patch('subprocess.Popen')
+ self.popen.return_value = self.popen_result
+
+ self.BaseInventoryPlugin = register_patch('ansible.plugins.inventory.BaseInventoryPlugin')
+ self.BaseInventoryPlugin.get_cache_prefix.return_value = 'abc123'
+
+ def test_parse_subprocess_path_not_found_fail(self):
+ self.popen.side_effect = OSError("dummy text")
+
+ with pytest.raises(AnsibleError) as e:
+ self.inventory_module.parse(self.inventory, self.loader, '/foo/bar/foobar.py')
+ assert e.value.message == "problem running /foo/bar/foobar.py --list (dummy text)"
+
+ def test_parse_subprocess_err_code_fail(self):
+ self.popen_result.stdout = to_bytes(u"fooébar", errors='surrogate_escape')
+ self.popen_result.stderr = to_bytes(u"dummyédata")
+
+ self.popen_result.returncode = 1
+
+ with pytest.raises(AnsibleError) as e:
+ self.inventory_module.parse(self.inventory, self.loader, '/foo/bar/foobar.py')
+ assert e.value.message == to_native("Inventory script (/foo/bar/foobar.py) had an execution error: "
+ "dummyédata\n ")
+
+ def test_parse_utf8_fail(self):
+ self.popen_result.returncode = 0
+ self.popen_result.stderr = to_bytes("dummyédata")
+ self.loader.load.side_effect = TypeError('obj must be string')
+
+ with pytest.raises(AnsibleError) as e:
+ self.inventory_module.parse(self.inventory, self.loader, '/foo/bar/foobar.py')
+ assert e.value.message == to_native("failed to parse executable inventory script results from "
+ "/foo/bar/foobar.py: obj must be string\ndummyédata\n")
+
+ def test_parse_dict_fail(self):
+ self.popen_result.returncode = 0
+ self.popen_result.stderr = to_bytes("dummyédata")
+ self.loader.load.return_value = 'i am not a dict'
+
+ with pytest.raises(AnsibleError) as e:
+ self.inventory_module.parse(self.inventory, self.loader, '/foo/bar/foobar.py')
+ assert e.value.message == to_native("failed to parse executable inventory script results from "
+ "/foo/bar/foobar.py: needs to be a json dict\ndummyédata\n")
diff --git a/test/units/plugins/loader_fixtures/__init__.py b/test/units/plugins/loader_fixtures/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/units/plugins/loader_fixtures/__init__.py
diff --git a/test/units/plugins/loader_fixtures/import_fixture.py b/test/units/plugins/loader_fixtures/import_fixture.py
new file mode 100644
index 0000000..8112733
--- /dev/null
+++ b/test/units/plugins/loader_fixtures/import_fixture.py
@@ -0,0 +1,9 @@
+# Nothing to see here, this file is just empty to support a imp.load_source
+# without doing anything
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+
+class test:
+ def __init__(self, *args, **kwargs):
+ pass
diff --git a/test/units/plugins/lookup/__init__.py b/test/units/plugins/lookup/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/units/plugins/lookup/__init__.py
diff --git a/test/units/plugins/lookup/test_env.py b/test/units/plugins/lookup/test_env.py
new file mode 100644
index 0000000..5d9713f
--- /dev/null
+++ b/test/units/plugins/lookup/test_env.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+# Copyright: (c) 2019, Abhay Kadam <abhaykadam88@gmail.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import pytest
+
+from ansible.plugins.loader import lookup_loader
+
+
+@pytest.mark.parametrize('env_var,exp_value', [
+ ('foo', 'bar'),
+ ('equation', 'a=b*100')
+])
+def test_env_var_value(monkeypatch, env_var, exp_value):
+ monkeypatch.setattr('ansible.utils.py3compat.environ.get', lambda x, y: exp_value)
+
+ env_lookup = lookup_loader.get('env')
+ retval = env_lookup.run([env_var], None)
+ assert retval == [exp_value]
+
+
+@pytest.mark.parametrize('env_var,exp_value', [
+ ('simple_var', 'alpha-β-gamma'),
+ ('the_var', 'ãnˈsiβle')
+])
+def test_utf8_env_var_value(monkeypatch, env_var, exp_value):
+ monkeypatch.setattr('ansible.utils.py3compat.environ.get', lambda x, y: exp_value)
+
+ env_lookup = lookup_loader.get('env')
+ retval = env_lookup.run([env_var], None)
+ assert retval == [exp_value]
diff --git a/test/units/plugins/lookup/test_ini.py b/test/units/plugins/lookup/test_ini.py
new file mode 100644
index 0000000..b2d883c
--- /dev/null
+++ b/test/units/plugins/lookup/test_ini.py
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+# (c) 2015, Toshio Kuratomi <tkuratomi@ansible.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from units.compat import unittest
+from ansible.plugins.lookup.ini import _parse_params
+
+
+class TestINILookup(unittest.TestCase):
+
+ # Currently there isn't a new-style
+ old_style_params_data = (
+ # Simple case
+ dict(
+ term=u'keyA section=sectionA file=/path/to/file',
+ expected=[u'file=/path/to/file', u'keyA', u'section=sectionA'],
+ ),
+ dict(
+ term=u'keyB section=sectionB with space file=/path/with/embedded spaces and/file',
+ expected=[u'file=/path/with/embedded spaces and/file', u'keyB', u'section=sectionB with space'],
+ ),
+ dict(
+ term=u'keyC section=sectionC file=/path/with/equals/cn=com.ansible',
+ expected=[u'file=/path/with/equals/cn=com.ansible', u'keyC', u'section=sectionC'],
+ ),
+ dict(
+ term=u'keyD section=sectionD file=/path/with space and/equals/cn=com.ansible',
+ expected=[u'file=/path/with space and/equals/cn=com.ansible', u'keyD', u'section=sectionD'],
+ ),
+ dict(
+ term=u'keyE section=sectionE file=/path/with/unicode/くらとみ/file',
+ expected=[u'file=/path/with/unicode/くらとみ/file', u'keyE', u'section=sectionE'],
+ ),
+ dict(
+ term=u'keyF section=sectionF file=/path/with/utf 8 and spaces/くらとみ/file',
+ expected=[u'file=/path/with/utf 8 and spaces/くらとみ/file', u'keyF', u'section=sectionF'],
+ ),
+ )
+
+ def test_parse_parameters(self):
+ pvals = {'file': '', 'section': '', 'key': '', 'type': '', 're': '', 'default': '', 'encoding': ''}
+ for testcase in self.old_style_params_data:
+ # print(testcase)
+ params = _parse_params(testcase['term'], pvals)
+ params.sort()
+ self.assertEqual(params, testcase['expected'])
diff --git a/test/units/plugins/lookup/test_password.py b/test/units/plugins/lookup/test_password.py
new file mode 100644
index 0000000..15207b2
--- /dev/null
+++ b/test/units/plugins/lookup/test_password.py
@@ -0,0 +1,577 @@
+# -*- coding: utf-8 -*-
+# (c) 2015, Toshio Kuratomi <tkuratomi@ansible.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+try:
+ import passlib
+ from passlib.handlers import pbkdf2
+except ImportError:
+ passlib = None
+ pbkdf2 = None
+
+import pytest
+
+from units.mock.loader import DictDataLoader
+
+from units.compat import unittest
+from unittest.mock import mock_open, patch
+from ansible.errors import AnsibleError
+from ansible.module_utils.six import text_type
+from ansible.module_utils.six.moves import builtins
+from ansible.module_utils._text import to_bytes
+from ansible.plugins.loader import PluginLoader, lookup_loader
+from ansible.plugins.lookup import password
+
+
+DEFAULT_LENGTH = 20
+DEFAULT_CHARS = sorted([u'ascii_letters', u'digits', u".,:-_"])
+DEFAULT_CANDIDATE_CHARS = u'.,:-_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
+
+# Currently there isn't a new-style
+old_style_params_data = (
+ # Simple case
+ dict(
+ term=u'/path/to/file',
+ filename=u'/path/to/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None, chars=DEFAULT_CHARS, seed=None),
+ candidate_chars=DEFAULT_CANDIDATE_CHARS,
+ ),
+
+ # Special characters in path
+ dict(
+ term=u'/path/with/embedded spaces and/file',
+ filename=u'/path/with/embedded spaces and/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None, chars=DEFAULT_CHARS, seed=None),
+ candidate_chars=DEFAULT_CANDIDATE_CHARS,
+ ),
+ dict(
+ term=u'/path/with/equals/cn=com.ansible',
+ filename=u'/path/with/equals/cn=com.ansible',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None, chars=DEFAULT_CHARS, seed=None),
+ candidate_chars=DEFAULT_CANDIDATE_CHARS,
+ ),
+ dict(
+ term=u'/path/with/unicode/くらとみ/file',
+ filename=u'/path/with/unicode/くらとみ/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None, chars=DEFAULT_CHARS, seed=None),
+ candidate_chars=DEFAULT_CANDIDATE_CHARS,
+ ),
+
+ # Mix several special chars
+ dict(
+ term=u'/path/with/utf 8 and spaces/くらとみ/file',
+ filename=u'/path/with/utf 8 and spaces/くらとみ/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None, chars=DEFAULT_CHARS, seed=None),
+ candidate_chars=DEFAULT_CANDIDATE_CHARS,
+ ),
+ dict(
+ term=u'/path/with/encoding=unicode/くらとみ/file',
+ filename=u'/path/with/encoding=unicode/くらとみ/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None, chars=DEFAULT_CHARS, seed=None),
+ candidate_chars=DEFAULT_CANDIDATE_CHARS,
+ ),
+ dict(
+ term=u'/path/with/encoding=unicode/くらとみ/and spaces file',
+ filename=u'/path/with/encoding=unicode/くらとみ/and spaces file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None, chars=DEFAULT_CHARS, seed=None),
+ candidate_chars=DEFAULT_CANDIDATE_CHARS,
+ ),
+
+ # Simple parameters
+ dict(
+ term=u'/path/to/file length=42',
+ filename=u'/path/to/file',
+ params=dict(length=42, encrypt=None, ident=None, chars=DEFAULT_CHARS, seed=None),
+ candidate_chars=DEFAULT_CANDIDATE_CHARS,
+ ),
+ dict(
+ term=u'/path/to/file encrypt=pbkdf2_sha256',
+ filename=u'/path/to/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt='pbkdf2_sha256', ident=None, chars=DEFAULT_CHARS, seed=None),
+ candidate_chars=DEFAULT_CANDIDATE_CHARS,
+ ),
+ dict(
+ term=u'/path/to/file chars=abcdefghijklmnop',
+ filename=u'/path/to/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None, chars=[u'abcdefghijklmnop'], seed=None),
+ candidate_chars=u'abcdefghijklmnop',
+ ),
+ dict(
+ term=u'/path/to/file chars=digits,abc,def',
+ filename=u'/path/to/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None,
+ chars=sorted([u'digits', u'abc', u'def']), seed=None),
+ candidate_chars=u'abcdef0123456789',
+ ),
+ dict(
+ term=u'/path/to/file seed=1',
+ filename=u'/path/to/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None, chars=DEFAULT_CHARS, seed='1'),
+ candidate_chars=DEFAULT_CANDIDATE_CHARS,
+ ),
+
+ # Including comma in chars
+ dict(
+ term=u'/path/to/file chars=abcdefghijklmnop,,digits',
+ filename=u'/path/to/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None,
+ chars=sorted([u'abcdefghijklmnop', u',', u'digits']), seed=None),
+ candidate_chars=u',abcdefghijklmnop0123456789',
+ ),
+ dict(
+ term=u'/path/to/file chars=,,',
+ filename=u'/path/to/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None,
+ chars=[u','], seed=None),
+ candidate_chars=u',',
+ ),
+
+ # Including = in chars
+ dict(
+ term=u'/path/to/file chars=digits,=,,',
+ filename=u'/path/to/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None,
+ chars=sorted([u'digits', u'=', u',']), seed=None),
+ candidate_chars=u',=0123456789',
+ ),
+ dict(
+ term=u'/path/to/file chars=digits,abc=def',
+ filename=u'/path/to/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None,
+ chars=sorted([u'digits', u'abc=def']), seed=None),
+ candidate_chars=u'abc=def0123456789',
+ ),
+
+ # Including unicode in chars
+ dict(
+ term=u'/path/to/file chars=digits,くらとみ,,',
+ filename=u'/path/to/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None,
+ chars=sorted([u'digits', u'くらとみ', u',']), seed=None),
+ candidate_chars=u',0123456789くらとみ',
+ ),
+ # Including only unicode in chars
+ dict(
+ term=u'/path/to/file chars=くらとみ',
+ filename=u'/path/to/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None,
+ chars=sorted([u'くらとみ']), seed=None),
+ candidate_chars=u'くらとみ',
+ ),
+
+ # Include ':' in path
+ dict(
+ term=u'/path/to/file_with:colon chars=ascii_letters,digits',
+ filename=u'/path/to/file_with:colon',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None,
+ chars=sorted([u'ascii_letters', u'digits']), seed=None),
+ candidate_chars=u'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789',
+ ),
+
+ # Including special chars in both path and chars
+ # Special characters in path
+ dict(
+ term=u'/path/with/embedded spaces and/file chars=abc=def',
+ filename=u'/path/with/embedded spaces and/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None, chars=[u'abc=def'], seed=None),
+ candidate_chars=u'abc=def',
+ ),
+ dict(
+ term=u'/path/with/equals/cn=com.ansible chars=abc=def',
+ filename=u'/path/with/equals/cn=com.ansible',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None, chars=[u'abc=def'], seed=None),
+ candidate_chars=u'abc=def',
+ ),
+ dict(
+ term=u'/path/with/unicode/くらとみ/file chars=くらとみ',
+ filename=u'/path/with/unicode/くらとみ/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, ident=None, chars=[u'くらとみ'], seed=None),
+ candidate_chars=u'くらとみ',
+ ),
+)
+
+
+class TestParseParameters(unittest.TestCase):
+
+ def setUp(self):
+ self.fake_loader = DictDataLoader({'/path/to/somewhere': 'sdfsdf'})
+ self.password_lookup = lookup_loader.get('password')
+ self.password_lookup._loader = self.fake_loader
+
+ def test(self):
+ for testcase in old_style_params_data:
+ filename, params = self.password_lookup._parse_parameters(testcase['term'])
+ params['chars'].sort()
+ self.assertEqual(filename, testcase['filename'])
+ self.assertEqual(params, testcase['params'])
+
+ def test_unrecognized_value(self):
+ testcase = dict(term=u'/path/to/file chars=くらとみi sdfsdf',
+ filename=u'/path/to/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, chars=[u'くらとみ']),
+ candidate_chars=u'くらとみ')
+ self.assertRaises(AnsibleError, self.password_lookup._parse_parameters, testcase['term'])
+
+ def test_invalid_params(self):
+ testcase = dict(term=u'/path/to/file chars=くらとみi somethign_invalid=123',
+ filename=u'/path/to/file',
+ params=dict(length=DEFAULT_LENGTH, encrypt=None, chars=[u'くらとみ']),
+ candidate_chars=u'くらとみ')
+ self.assertRaises(AnsibleError, self.password_lookup._parse_parameters, testcase['term'])
+
+
+class TestReadPasswordFile(unittest.TestCase):
+ def setUp(self):
+ self.os_path_exists = password.os.path.exists
+
+ def tearDown(self):
+ password.os.path.exists = self.os_path_exists
+
+ def test_no_password_file(self):
+ password.os.path.exists = lambda x: False
+ self.assertEqual(password._read_password_file(b'/nonexistent'), None)
+
+ def test_with_password_file(self):
+ password.os.path.exists = lambda x: True
+ with patch.object(builtins, 'open', mock_open(read_data=b'Testing\n')) as m:
+ self.assertEqual(password._read_password_file(b'/etc/motd'), u'Testing')
+
+
+class TestGenCandidateChars(unittest.TestCase):
+ def _assert_gen_candidate_chars(self, testcase):
+ expected_candidate_chars = testcase['candidate_chars']
+ params = testcase['params']
+ chars_spec = params['chars']
+ res = password._gen_candidate_chars(chars_spec)
+ self.assertEqual(res, expected_candidate_chars)
+
+ def test_gen_candidate_chars(self):
+ for testcase in old_style_params_data:
+ self._assert_gen_candidate_chars(testcase)
+
+
+class TestRandomPassword(unittest.TestCase):
+ def _assert_valid_chars(self, res, chars):
+ for res_char in res:
+ self.assertIn(res_char, chars)
+
+ def test_default(self):
+ res = password.random_password()
+ self.assertEqual(len(res), DEFAULT_LENGTH)
+ self.assertTrue(isinstance(res, text_type))
+ self._assert_valid_chars(res, DEFAULT_CANDIDATE_CHARS)
+
+ def test_zero_length(self):
+ res = password.random_password(length=0)
+ self.assertEqual(len(res), 0)
+ self.assertTrue(isinstance(res, text_type))
+ self._assert_valid_chars(res, u',')
+
+ def test_just_a_common(self):
+ res = password.random_password(length=1, chars=u',')
+ self.assertEqual(len(res), 1)
+ self.assertEqual(res, u',')
+
+ def test_free_will(self):
+ # A Rush and Spinal Tap reference twofer
+ res = password.random_password(length=11, chars=u'a')
+ self.assertEqual(len(res), 11)
+ self.assertEqual(res, 'aaaaaaaaaaa')
+ self._assert_valid_chars(res, u'a')
+
+ def test_unicode(self):
+ res = password.random_password(length=11, chars=u'くらとみ')
+ self._assert_valid_chars(res, u'くらとみ')
+ self.assertEqual(len(res), 11)
+
+ def test_seed(self):
+ pw1 = password.random_password(seed=1)
+ pw2 = password.random_password(seed=1)
+ pw3 = password.random_password(seed=2)
+ self.assertEqual(pw1, pw2)
+ self.assertNotEqual(pw1, pw3)
+
+ def test_gen_password(self):
+ for testcase in old_style_params_data:
+ params = testcase['params']
+ candidate_chars = testcase['candidate_chars']
+ params_chars_spec = password._gen_candidate_chars(params['chars'])
+ password_string = password.random_password(length=params['length'],
+ chars=params_chars_spec)
+ self.assertEqual(len(password_string),
+ params['length'],
+ msg='generated password=%s has length (%s) instead of expected length (%s)' %
+ (password_string, len(password_string), params['length']))
+
+ for char in password_string:
+ self.assertIn(char, candidate_chars,
+ msg='%s not found in %s from chars spect %s' %
+ (char, candidate_chars, params['chars']))
+
+
+class TestParseContent(unittest.TestCase):
+
+ def test_empty_password_file(self):
+ plaintext_password, salt = password._parse_content(u'')
+ self.assertEqual(plaintext_password, u'')
+ self.assertEqual(salt, None)
+
+ def test(self):
+ expected_content = u'12345678'
+ file_content = expected_content
+ plaintext_password, salt = password._parse_content(file_content)
+ self.assertEqual(plaintext_password, expected_content)
+ self.assertEqual(salt, None)
+
+ def test_with_salt(self):
+ expected_content = u'12345678 salt=87654321'
+ file_content = expected_content
+ plaintext_password, salt = password._parse_content(file_content)
+ self.assertEqual(plaintext_password, u'12345678')
+ self.assertEqual(salt, u'87654321')
+
+
+class TestFormatContent(unittest.TestCase):
+ def test_no_encrypt(self):
+ self.assertEqual(
+ password._format_content(password=u'hunter42',
+ salt=u'87654321',
+ encrypt=False),
+ u'hunter42 salt=87654321')
+
+ def test_no_encrypt_no_salt(self):
+ self.assertEqual(
+ password._format_content(password=u'hunter42',
+ salt=None,
+ encrypt=None),
+ u'hunter42')
+
+ def test_encrypt(self):
+ self.assertEqual(
+ password._format_content(password=u'hunter42',
+ salt=u'87654321',
+ encrypt='pbkdf2_sha256'),
+ u'hunter42 salt=87654321')
+
+ def test_encrypt_no_salt(self):
+ self.assertRaises(AssertionError, password._format_content, u'hunter42', None, 'pbkdf2_sha256')
+
+
+class TestWritePasswordFile(unittest.TestCase):
+ def setUp(self):
+ self.makedirs_safe = password.makedirs_safe
+ self.os_chmod = password.os.chmod
+ password.makedirs_safe = lambda path, mode: None
+ password.os.chmod = lambda path, mode: None
+
+ def tearDown(self):
+ password.makedirs_safe = self.makedirs_safe
+ password.os.chmod = self.os_chmod
+
+ def test_content_written(self):
+
+ with patch.object(builtins, 'open', mock_open()) as m:
+ password._write_password_file(b'/this/is/a/test/caf\xc3\xa9', u'Testing Café')
+
+ m.assert_called_once_with(b'/this/is/a/test/caf\xc3\xa9', 'wb')
+ m().write.assert_called_once_with(u'Testing Café\n'.encode('utf-8'))
+
+
+class BaseTestLookupModule(unittest.TestCase):
+ def setUp(self):
+ self.fake_loader = DictDataLoader({'/path/to/somewhere': 'sdfsdf'})
+ self.password_lookup = lookup_loader.get('password')
+ self.password_lookup._loader = self.fake_loader
+ self.os_path_exists = password.os.path.exists
+ self.os_open = password.os.open
+ password.os.open = lambda path, flag: None
+ self.os_close = password.os.close
+ password.os.close = lambda fd: None
+ self.os_remove = password.os.remove
+ password.os.remove = lambda path: None
+ self.makedirs_safe = password.makedirs_safe
+ password.makedirs_safe = lambda path, mode: None
+
+ def tearDown(self):
+ password.os.path.exists = self.os_path_exists
+ password.os.open = self.os_open
+ password.os.close = self.os_close
+ password.os.remove = self.os_remove
+ password.makedirs_safe = self.makedirs_safe
+
+
+class TestLookupModuleWithoutPasslib(BaseTestLookupModule):
+ @patch.object(PluginLoader, '_get_paths')
+ @patch('ansible.plugins.lookup.password._write_password_file')
+ def test_no_encrypt(self, mock_get_paths, mock_write_file):
+ mock_get_paths.return_value = ['/path/one', '/path/two', '/path/three']
+
+ results = self.password_lookup.run([u'/path/to/somewhere'], None)
+
+ # FIXME: assert something useful
+ for result in results:
+ assert len(result) == DEFAULT_LENGTH
+ assert isinstance(result, text_type)
+
+ @patch.object(PluginLoader, '_get_paths')
+ @patch('ansible.plugins.lookup.password._write_password_file')
+ def test_password_already_created_no_encrypt(self, mock_get_paths, mock_write_file):
+ mock_get_paths.return_value = ['/path/one', '/path/two', '/path/three']
+ password.os.path.exists = lambda x: x == to_bytes('/path/to/somewhere')
+
+ with patch.object(builtins, 'open', mock_open(read_data=b'hunter42 salt=87654321\n')) as m:
+ results = self.password_lookup.run([u'/path/to/somewhere chars=anything'], None)
+
+ for result in results:
+ self.assertEqual(result, u'hunter42')
+
+ @patch.object(PluginLoader, '_get_paths')
+ @patch('ansible.plugins.lookup.password._write_password_file')
+ def test_only_a(self, mock_get_paths, mock_write_file):
+ mock_get_paths.return_value = ['/path/one', '/path/two', '/path/three']
+
+ results = self.password_lookup.run([u'/path/to/somewhere chars=a'], None)
+ for result in results:
+ self.assertEqual(result, u'a' * DEFAULT_LENGTH)
+
+ @patch('time.sleep')
+ def test_lock_been_held(self, mock_sleep):
+ # pretend the lock file is here
+ password.os.path.exists = lambda x: True
+ try:
+ with patch.object(builtins, 'open', mock_open(read_data=b'hunter42 salt=87654321\n')) as m:
+ # should timeout here
+ results = self.password_lookup.run([u'/path/to/somewhere chars=anything'], None)
+ self.fail("Lookup didn't timeout when lock already been held")
+ except AnsibleError:
+ pass
+
+ def test_lock_not_been_held(self):
+ # pretend now there is password file but no lock
+ password.os.path.exists = lambda x: x == to_bytes('/path/to/somewhere')
+ try:
+ with patch.object(builtins, 'open', mock_open(read_data=b'hunter42 salt=87654321\n')) as m:
+ # should not timeout here
+ results = self.password_lookup.run([u'/path/to/somewhere chars=anything'], None)
+ except AnsibleError:
+ self.fail('Lookup timeouts when lock is free')
+
+ for result in results:
+ self.assertEqual(result, u'hunter42')
+
+
+@pytest.mark.skipif(passlib is None, reason='passlib must be installed to run these tests')
+class TestLookupModuleWithPasslib(BaseTestLookupModule):
+ def setUp(self):
+ super(TestLookupModuleWithPasslib, self).setUp()
+
+ # Different releases of passlib default to a different number of rounds
+ self.sha256 = passlib.registry.get_crypt_handler('pbkdf2_sha256')
+ sha256_for_tests = pbkdf2.create_pbkdf2_hash("sha256", 32, 20000)
+ passlib.registry.register_crypt_handler(sha256_for_tests, force=True)
+
+ def tearDown(self):
+ super(TestLookupModuleWithPasslib, self).tearDown()
+
+ passlib.registry.register_crypt_handler(self.sha256, force=True)
+
+ @patch.object(PluginLoader, '_get_paths')
+ @patch('ansible.plugins.lookup.password._write_password_file')
+ def test_encrypt(self, mock_get_paths, mock_write_file):
+ mock_get_paths.return_value = ['/path/one', '/path/two', '/path/three']
+
+ results = self.password_lookup.run([u'/path/to/somewhere encrypt=pbkdf2_sha256'], None)
+
+ # pbkdf2 format plus hash
+ expected_password_length = 76
+
+ for result in results:
+ self.assertEqual(len(result), expected_password_length)
+ # result should have 5 parts split by '$'
+ str_parts = result.split('$', 5)
+
+ # verify the result is parseable by the passlib
+ crypt_parts = passlib.hash.pbkdf2_sha256.parsehash(result)
+
+ # verify it used the right algo type
+ self.assertEqual(str_parts[1], 'pbkdf2-sha256')
+
+ self.assertEqual(len(str_parts), 5)
+
+ # verify the string and parsehash agree on the number of rounds
+ self.assertEqual(int(str_parts[2]), crypt_parts['rounds'])
+ self.assertIsInstance(result, text_type)
+
+ @patch.object(PluginLoader, '_get_paths')
+ @patch('ansible.plugins.lookup.password._write_password_file')
+ def test_password_already_created_encrypt(self, mock_get_paths, mock_write_file):
+ mock_get_paths.return_value = ['/path/one', '/path/two', '/path/three']
+ password.os.path.exists = lambda x: x == to_bytes('/path/to/somewhere')
+
+ with patch.object(builtins, 'open', mock_open(read_data=b'hunter42 salt=87654321\n')) as m:
+ results = self.password_lookup.run([u'/path/to/somewhere chars=anything encrypt=pbkdf2_sha256'], None)
+ for result in results:
+ self.assertEqual(result, u'$pbkdf2-sha256$20000$ODc2NTQzMjE$Uikde0cv0BKaRaAXMrUQB.zvG4GmnjClwjghwIRf2gU')
+
+
+@pytest.mark.skipif(passlib is None, reason='passlib must be installed to run these tests')
+class TestLookupModuleWithPasslibWrappedAlgo(BaseTestLookupModule):
+ def setUp(self):
+ super(TestLookupModuleWithPasslibWrappedAlgo, self).setUp()
+ self.os_path_exists = password.os.path.exists
+
+ def tearDown(self):
+ super(TestLookupModuleWithPasslibWrappedAlgo, self).tearDown()
+ password.os.path.exists = self.os_path_exists
+
+ @patch('ansible.plugins.lookup.password._write_password_file')
+ def test_encrypt_wrapped_crypt_algo(self, mock_write_file):
+
+ password.os.path.exists = self.password_lookup._loader.path_exists
+ with patch.object(builtins, 'open', mock_open(read_data=self.password_lookup._loader._get_file_contents('/path/to/somewhere')[0])) as m:
+ results = self.password_lookup.run([u'/path/to/somewhere encrypt=ldap_sha256_crypt'], None)
+
+ wrapper = getattr(passlib.hash, 'ldap_sha256_crypt')
+
+ self.assertEqual(len(results), 1)
+ result = results[0]
+ self.assertIsInstance(result, text_type)
+
+ expected_password_length = 76
+ self.assertEqual(len(result), expected_password_length)
+
+ # result should have 5 parts split by '$'
+ str_parts = result.split('$')
+ self.assertEqual(len(str_parts), 5)
+
+ # verify the string and passlib agree on the number of rounds
+ self.assertEqual(str_parts[2], "rounds=%s" % wrapper.default_rounds)
+
+ # verify it used the right algo type
+ self.assertEqual(str_parts[0], '{CRYPT}')
+
+ # verify it used the right algo type
+ self.assertTrue(wrapper.verify(self.password_lookup._loader._get_file_contents('/path/to/somewhere')[0], result))
+
+ # verify a password with a non default rounds value
+ # generated with: echo test | mkpasswd -s --rounds 660000 -m sha-256 --salt testansiblepass.
+ hashpw = '{CRYPT}$5$rounds=660000$testansiblepass.$KlRSdA3iFXoPI.dEwh7AixiXW3EtCkLrlQvlYA2sluD'
+ self.assertTrue(wrapper.verify('test', hashpw))
diff --git a/test/units/plugins/lookup/test_url.py b/test/units/plugins/lookup/test_url.py
new file mode 100644
index 0000000..2aa77b3
--- /dev/null
+++ b/test/units/plugins/lookup/test_url.py
@@ -0,0 +1,26 @@
+# -*- coding: utf-8 -*-
+# Copyright: (c) 2020, Sam Doran <sdoran@redhat.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import pytest
+
+from ansible.plugins.loader import lookup_loader
+
+
+@pytest.mark.parametrize(
+ ('kwargs', 'agent'),
+ (
+ ({}, 'ansible-httpget'),
+ ({'http_agent': 'SuperFox'}, 'SuperFox'),
+ )
+)
+def test_user_agent(mocker, kwargs, agent):
+ mock_open_url = mocker.patch('ansible.plugins.lookup.url.open_url', side_effect=AttributeError('raised intentionally'))
+ url_lookup = lookup_loader.get('url')
+ with pytest.raises(AttributeError):
+ url_lookup.run(['https://nourl'], **kwargs)
+ assert 'http_agent' in mock_open_url.call_args.kwargs
+ assert mock_open_url.call_args.kwargs['http_agent'] == agent
diff --git a/test/units/plugins/shell/__init__.py b/test/units/plugins/shell/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/units/plugins/shell/__init__.py
diff --git a/test/units/plugins/shell/test_cmd.py b/test/units/plugins/shell/test_cmd.py
new file mode 100644
index 0000000..4c1a654
--- /dev/null
+++ b/test/units/plugins/shell/test_cmd.py
@@ -0,0 +1,19 @@
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import pytest
+
+from ansible.plugins.shell.cmd import ShellModule
+
+
+@pytest.mark.parametrize('s, expected', [
+ ['arg1', 'arg1'],
+ [None, '""'],
+ ['arg1 and 2', '^"arg1 and 2^"'],
+ ['malicious argument\\"&whoami', '^"malicious argument\\\\^"^&whoami^"'],
+ ['C:\\temp\\some ^%file% > nul', '^"C:\\temp\\some ^^^%file^% ^> nul^"']
+])
+def test_quote_args(s, expected):
+ cmd = ShellModule()
+ actual = cmd.quote(s)
+ assert actual == expected
diff --git a/test/units/plugins/shell/test_powershell.py b/test/units/plugins/shell/test_powershell.py
new file mode 100644
index 0000000..c94baab
--- /dev/null
+++ b/test/units/plugins/shell/test_powershell.py
@@ -0,0 +1,83 @@
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from ansible.plugins.shell.powershell import _parse_clixml, ShellModule
+
+
+def test_parse_clixml_empty():
+ empty = b'#< CLIXML\r\n<Objs Version="1.1.0.1" xmlns="http://schemas.microsoft.com/powershell/2004/04"></Objs>'
+ expected = b''
+ actual = _parse_clixml(empty)
+ assert actual == expected
+
+
+def test_parse_clixml_with_progress():
+ progress = b'#< CLIXML\r\n<Objs Version="1.1.0.1" xmlns="http://schemas.microsoft.com/powershell/2004/04">' \
+ b'<Obj S="progress" RefId="0"><TN RefId="0"><T>System.Management.Automation.PSCustomObject</T><T>System.Object</T></TN><MS>' \
+ b'<I64 N="SourceId">1</I64><PR N="Record"><AV>Preparing modules for first use.</AV><AI>0</AI><Nil />' \
+ b'<PI>-1</PI><PC>-1</PC><T>Completed</T><SR>-1</SR><SD> </SD></PR></MS></Obj></Objs>'
+ expected = b''
+ actual = _parse_clixml(progress)
+ assert actual == expected
+
+
+def test_parse_clixml_single_stream():
+ single_stream = b'#< CLIXML\r\n<Objs Version="1.1.0.1" xmlns="http://schemas.microsoft.com/powershell/2004/04">' \
+ b'<S S="Error">fake : The term \'fake\' is not recognized as the name of a cmdlet. Check _x000D__x000A_</S>' \
+ b'<S S="Error">the spelling of the name, or if a path was included._x000D__x000A_</S>' \
+ b'<S S="Error">At line:1 char:1_x000D__x000A_</S>' \
+ b'<S S="Error">+ fake cmdlet_x000D__x000A_</S><S S="Error">+ ~~~~_x000D__x000A_</S>' \
+ b'<S S="Error"> + CategoryInfo : ObjectNotFound: (fake:String) [], CommandNotFoundException_x000D__x000A_</S>' \
+ b'<S S="Error"> + FullyQualifiedErrorId : CommandNotFoundException_x000D__x000A_</S><S S="Error"> _x000D__x000A_</S>' \
+ b'</Objs>'
+ expected = b"fake : The term 'fake' is not recognized as the name of a cmdlet. Check \r\n" \
+ b"the spelling of the name, or if a path was included.\r\n" \
+ b"At line:1 char:1\r\n" \
+ b"+ fake cmdlet\r\n" \
+ b"+ ~~~~\r\n" \
+ b" + CategoryInfo : ObjectNotFound: (fake:String) [], CommandNotFoundException\r\n" \
+ b" + FullyQualifiedErrorId : CommandNotFoundException\r\n "
+ actual = _parse_clixml(single_stream)
+ assert actual == expected
+
+
+def test_parse_clixml_multiple_streams():
+ multiple_stream = b'#< CLIXML\r\n<Objs Version="1.1.0.1" xmlns="http://schemas.microsoft.com/powershell/2004/04">' \
+ b'<S S="Error">fake : The term \'fake\' is not recognized as the name of a cmdlet. Check _x000D__x000A_</S>' \
+ b'<S S="Error">the spelling of the name, or if a path was included._x000D__x000A_</S>' \
+ b'<S S="Error">At line:1 char:1_x000D__x000A_</S>' \
+ b'<S S="Error">+ fake cmdlet_x000D__x000A_</S><S S="Error">+ ~~~~_x000D__x000A_</S>' \
+ b'<S S="Error"> + CategoryInfo : ObjectNotFound: (fake:String) [], CommandNotFoundException_x000D__x000A_</S>' \
+ b'<S S="Error"> + FullyQualifiedErrorId : CommandNotFoundException_x000D__x000A_</S><S S="Error"> _x000D__x000A_</S>' \
+ b'<S S="Info">hi info</S>' \
+ b'</Objs>'
+ expected = b"hi info"
+ actual = _parse_clixml(multiple_stream, stream="Info")
+ assert actual == expected
+
+
+def test_parse_clixml_multiple_elements():
+ multiple_elements = b'#< CLIXML\r\n#< CLIXML\r\n<Objs Version="1.1.0.1" xmlns="http://schemas.microsoft.com/powershell/2004/04">' \
+ b'<Obj S="progress" RefId="0"><TN RefId="0"><T>System.Management.Automation.PSCustomObject</T><T>System.Object</T></TN><MS>' \
+ b'<I64 N="SourceId">1</I64><PR N="Record"><AV>Preparing modules for first use.</AV><AI>0</AI><Nil />' \
+ b'<PI>-1</PI><PC>-1</PC><T>Completed</T><SR>-1</SR><SD> </SD></PR></MS></Obj>' \
+ b'<S S="Error">Error 1</S></Objs>' \
+ b'<Objs Version="1.1.0.1" xmlns="http://schemas.microsoft.com/powershell/2004/04"><Obj S="progress" RefId="0">' \
+ b'<TN RefId="0"><T>System.Management.Automation.PSCustomObject</T><T>System.Object</T></TN><MS>' \
+ b'<I64 N="SourceId">1</I64><PR N="Record"><AV>Preparing modules for first use.</AV><AI>0</AI><Nil />' \
+ b'<PI>-1</PI><PC>-1</PC><T>Completed</T><SR>-1</SR><SD> </SD></PR></MS></Obj>' \
+ b'<Obj S="progress" RefId="1"><TNRef RefId="0" /><MS><I64 N="SourceId">2</I64>' \
+ b'<PR N="Record"><AV>Preparing modules for first use.</AV><AI>0</AI><Nil />' \
+ b'<PI>-1</PI><PC>-1</PC><T>Completed</T><SR>-1</SR><SD> </SD></PR></MS></Obj>' \
+ b'<S S="Error">Error 2</S></Objs>'
+ expected = b"Error 1\r\nError 2"
+ actual = _parse_clixml(multiple_elements)
+ assert actual == expected
+
+
+def test_join_path_unc():
+ pwsh = ShellModule()
+ unc_path_parts = ['\\\\host\\share\\dir1\\\\dir2\\', '\\dir3/dir4', 'dir5', 'dir6\\']
+ expected = '\\\\host\\share\\dir1\\dir2\\dir3\\dir4\\dir5\\dir6'
+ actual = pwsh.join_path(*unc_path_parts)
+ assert actual == expected
diff --git a/test/units/plugins/strategy/__init__.py b/test/units/plugins/strategy/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/units/plugins/strategy/__init__.py
diff --git a/test/units/plugins/strategy/test_linear.py b/test/units/plugins/strategy/test_linear.py
new file mode 100644
index 0000000..b39c142
--- /dev/null
+++ b/test/units/plugins/strategy/test_linear.py
@@ -0,0 +1,320 @@
+# Copyright (c) 2018 Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+
+from units.compat import unittest
+from unittest.mock import patch, MagicMock
+
+from ansible.executor.play_iterator import PlayIterator
+from ansible.playbook import Playbook
+from ansible.playbook.play_context import PlayContext
+from ansible.plugins.strategy.linear import StrategyModule
+from ansible.executor.task_queue_manager import TaskQueueManager
+
+from units.mock.loader import DictDataLoader
+from units.mock.path import mock_unfrackpath_noop
+
+
+class TestStrategyLinear(unittest.TestCase):
+
+ @patch('ansible.playbook.role.definition.unfrackpath', mock_unfrackpath_noop)
+ def test_noop(self):
+ fake_loader = DictDataLoader({
+ "test_play.yml": """
+ - hosts: all
+ gather_facts: no
+ tasks:
+ - block:
+ - block:
+ - name: task1
+ debug: msg='task1'
+ failed_when: inventory_hostname == 'host01'
+
+ - name: task2
+ debug: msg='task2'
+
+ rescue:
+ - name: rescue1
+ debug: msg='rescue1'
+
+ - name: rescue2
+ debug: msg='rescue2'
+ """,
+ })
+
+ mock_var_manager = MagicMock()
+ mock_var_manager._fact_cache = dict()
+ mock_var_manager.get_vars.return_value = dict()
+
+ p = Playbook.load('test_play.yml', loader=fake_loader, variable_manager=mock_var_manager)
+
+ inventory = MagicMock()
+ inventory.hosts = {}
+ hosts = []
+ for i in range(0, 2):
+ host = MagicMock()
+ host.name = host.get_name.return_value = 'host%02d' % i
+ hosts.append(host)
+ inventory.hosts[host.name] = host
+ inventory.get_hosts.return_value = hosts
+ inventory.filter_hosts.return_value = hosts
+
+ mock_var_manager._fact_cache['host00'] = dict()
+
+ play_context = PlayContext(play=p._entries[0])
+
+ itr = PlayIterator(
+ inventory=inventory,
+ play=p._entries[0],
+ play_context=play_context,
+ variable_manager=mock_var_manager,
+ all_vars=dict(),
+ )
+
+ tqm = TaskQueueManager(
+ inventory=inventory,
+ variable_manager=mock_var_manager,
+ loader=fake_loader,
+ passwords=None,
+ forks=5,
+ )
+ tqm._initialize_processes(3)
+ strategy = StrategyModule(tqm)
+ strategy._hosts_cache = [h.name for h in hosts]
+ strategy._hosts_cache_all = [h.name for h in hosts]
+
+ # implicit meta: flush_handlers
+ hosts_left = strategy.get_hosts_left(itr)
+ hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr)
+ host1_task = hosts_tasks[0][1]
+ host2_task = hosts_tasks[1][1]
+ self.assertIsNotNone(host1_task)
+ self.assertIsNotNone(host2_task)
+ self.assertEqual(host1_task.action, 'meta')
+ self.assertEqual(host2_task.action, 'meta')
+
+ # debug: task1, debug: task1
+ hosts_left = strategy.get_hosts_left(itr)
+ hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr)
+ host1_task = hosts_tasks[0][1]
+ host2_task = hosts_tasks[1][1]
+ self.assertIsNotNone(host1_task)
+ self.assertIsNotNone(host2_task)
+ self.assertEqual(host1_task.action, 'debug')
+ self.assertEqual(host2_task.action, 'debug')
+ self.assertEqual(host1_task.name, 'task1')
+ self.assertEqual(host2_task.name, 'task1')
+
+ # mark the second host failed
+ itr.mark_host_failed(hosts[1])
+
+ # debug: task2, meta: noop
+ hosts_left = strategy.get_hosts_left(itr)
+ hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr)
+ host1_task = hosts_tasks[0][1]
+ host2_task = hosts_tasks[1][1]
+ self.assertIsNotNone(host1_task)
+ self.assertIsNotNone(host2_task)
+ self.assertEqual(host1_task.action, 'debug')
+ self.assertEqual(host2_task.action, 'meta')
+ self.assertEqual(host1_task.name, 'task2')
+ self.assertEqual(host2_task.name, '')
+
+ # meta: noop, debug: rescue1
+ hosts_left = strategy.get_hosts_left(itr)
+ hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr)
+ host1_task = hosts_tasks[0][1]
+ host2_task = hosts_tasks[1][1]
+ self.assertIsNotNone(host1_task)
+ self.assertIsNotNone(host2_task)
+ self.assertEqual(host1_task.action, 'meta')
+ self.assertEqual(host2_task.action, 'debug')
+ self.assertEqual(host1_task.name, '')
+ self.assertEqual(host2_task.name, 'rescue1')
+
+ # meta: noop, debug: rescue2
+ hosts_left = strategy.get_hosts_left(itr)
+ hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr)
+ host1_task = hosts_tasks[0][1]
+ host2_task = hosts_tasks[1][1]
+ self.assertIsNotNone(host1_task)
+ self.assertIsNotNone(host2_task)
+ self.assertEqual(host1_task.action, 'meta')
+ self.assertEqual(host2_task.action, 'debug')
+ self.assertEqual(host1_task.name, '')
+ self.assertEqual(host2_task.name, 'rescue2')
+
+ # implicit meta: flush_handlers
+ hosts_left = strategy.get_hosts_left(itr)
+ hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr)
+ host1_task = hosts_tasks[0][1]
+ host2_task = hosts_tasks[1][1]
+ self.assertIsNotNone(host1_task)
+ self.assertIsNotNone(host2_task)
+ self.assertEqual(host1_task.action, 'meta')
+ self.assertEqual(host2_task.action, 'meta')
+
+ # implicit meta: flush_handlers
+ hosts_left = strategy.get_hosts_left(itr)
+ hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr)
+ host1_task = hosts_tasks[0][1]
+ host2_task = hosts_tasks[1][1]
+ self.assertIsNotNone(host1_task)
+ self.assertIsNotNone(host2_task)
+ self.assertEqual(host1_task.action, 'meta')
+ self.assertEqual(host2_task.action, 'meta')
+
+ # end of iteration
+ hosts_left = strategy.get_hosts_left(itr)
+ hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr)
+ host1_task = hosts_tasks[0][1]
+ host2_task = hosts_tasks[1][1]
+ self.assertIsNone(host1_task)
+ self.assertIsNone(host2_task)
+
+ def test_noop_64999(self):
+ fake_loader = DictDataLoader({
+ "test_play.yml": """
+ - hosts: all
+ gather_facts: no
+ tasks:
+ - name: block1
+ block:
+ - name: block2
+ block:
+ - name: block3
+ block:
+ - name: task1
+ debug:
+ failed_when: inventory_hostname == 'host01'
+ rescue:
+ - name: rescue1
+ debug:
+ msg: "rescue"
+ - name: after_rescue1
+ debug:
+ msg: "after_rescue1"
+ """,
+ })
+
+ mock_var_manager = MagicMock()
+ mock_var_manager._fact_cache = dict()
+ mock_var_manager.get_vars.return_value = dict()
+
+ p = Playbook.load('test_play.yml', loader=fake_loader, variable_manager=mock_var_manager)
+
+ inventory = MagicMock()
+ inventory.hosts = {}
+ hosts = []
+ for i in range(0, 2):
+ host = MagicMock()
+ host.name = host.get_name.return_value = 'host%02d' % i
+ hosts.append(host)
+ inventory.hosts[host.name] = host
+ inventory.get_hosts.return_value = hosts
+ inventory.filter_hosts.return_value = hosts
+
+ mock_var_manager._fact_cache['host00'] = dict()
+
+ play_context = PlayContext(play=p._entries[0])
+
+ itr = PlayIterator(
+ inventory=inventory,
+ play=p._entries[0],
+ play_context=play_context,
+ variable_manager=mock_var_manager,
+ all_vars=dict(),
+ )
+
+ tqm = TaskQueueManager(
+ inventory=inventory,
+ variable_manager=mock_var_manager,
+ loader=fake_loader,
+ passwords=None,
+ forks=5,
+ )
+ tqm._initialize_processes(3)
+ strategy = StrategyModule(tqm)
+ strategy._hosts_cache = [h.name for h in hosts]
+ strategy._hosts_cache_all = [h.name for h in hosts]
+
+ # implicit meta: flush_handlers
+ hosts_left = strategy.get_hosts_left(itr)
+ hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr)
+ host1_task = hosts_tasks[0][1]
+ host2_task = hosts_tasks[1][1]
+ self.assertIsNotNone(host1_task)
+ self.assertIsNotNone(host2_task)
+ self.assertEqual(host1_task.action, 'meta')
+ self.assertEqual(host2_task.action, 'meta')
+
+ # debug: task1, debug: task1
+ hosts_left = strategy.get_hosts_left(itr)
+ hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr)
+ host1_task = hosts_tasks[0][1]
+ host2_task = hosts_tasks[1][1]
+ self.assertIsNotNone(host1_task)
+ self.assertIsNotNone(host2_task)
+ self.assertEqual(host1_task.action, 'debug')
+ self.assertEqual(host2_task.action, 'debug')
+ self.assertEqual(host1_task.name, 'task1')
+ self.assertEqual(host2_task.name, 'task1')
+
+ # mark the second host failed
+ itr.mark_host_failed(hosts[1])
+
+ # meta: noop, debug: rescue1
+ hosts_left = strategy.get_hosts_left(itr)
+ hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr)
+ host1_task = hosts_tasks[0][1]
+ host2_task = hosts_tasks[1][1]
+ self.assertIsNotNone(host1_task)
+ self.assertIsNotNone(host2_task)
+ self.assertEqual(host1_task.action, 'meta')
+ self.assertEqual(host2_task.action, 'debug')
+ self.assertEqual(host1_task.name, '')
+ self.assertEqual(host2_task.name, 'rescue1')
+
+ # debug: after_rescue1, debug: after_rescue1
+ hosts_left = strategy.get_hosts_left(itr)
+ hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr)
+ host1_task = hosts_tasks[0][1]
+ host2_task = hosts_tasks[1][1]
+ self.assertIsNotNone(host1_task)
+ self.assertIsNotNone(host2_task)
+ self.assertEqual(host1_task.action, 'debug')
+ self.assertEqual(host2_task.action, 'debug')
+ self.assertEqual(host1_task.name, 'after_rescue1')
+ self.assertEqual(host2_task.name, 'after_rescue1')
+
+ # implicit meta: flush_handlers
+ hosts_left = strategy.get_hosts_left(itr)
+ hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr)
+ host1_task = hosts_tasks[0][1]
+ host2_task = hosts_tasks[1][1]
+ self.assertIsNotNone(host1_task)
+ self.assertIsNotNone(host2_task)
+ self.assertEqual(host1_task.action, 'meta')
+ self.assertEqual(host2_task.action, 'meta')
+
+ # implicit meta: flush_handlers
+ hosts_left = strategy.get_hosts_left(itr)
+ hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr)
+ host1_task = hosts_tasks[0][1]
+ host2_task = hosts_tasks[1][1]
+ self.assertIsNotNone(host1_task)
+ self.assertIsNotNone(host2_task)
+ self.assertEqual(host1_task.action, 'meta')
+ self.assertEqual(host2_task.action, 'meta')
+
+ # end of iteration
+ hosts_left = strategy.get_hosts_left(itr)
+ hosts_tasks = strategy._get_next_task_lockstep(hosts_left, itr)
+ host1_task = hosts_tasks[0][1]
+ host2_task = hosts_tasks[1][1]
+ self.assertIsNone(host1_task)
+ self.assertIsNone(host2_task)
diff --git a/test/units/plugins/strategy/test_strategy.py b/test/units/plugins/strategy/test_strategy.py
new file mode 100644
index 0000000..f935f4b
--- /dev/null
+++ b/test/units/plugins/strategy/test_strategy.py
@@ -0,0 +1,492 @@
+# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from units.mock.loader import DictDataLoader
+import uuid
+
+from units.compat import unittest
+from unittest.mock import patch, MagicMock
+from ansible.executor.process.worker import WorkerProcess
+from ansible.executor.task_queue_manager import TaskQueueManager
+from ansible.executor.task_result import TaskResult
+from ansible.inventory.host import Host
+from ansible.module_utils.six.moves import queue as Queue
+from ansible.playbook.block import Block
+from ansible.playbook.handler import Handler
+from ansible.plugins.strategy import StrategyBase
+
+import pytest
+
+pytestmark = pytest.mark.skipif(True, reason="Temporarily disabled due to fragile tests that need rewritten")
+
+
+class TestStrategyBase(unittest.TestCase):
+
+ def test_strategy_base_init(self):
+ queue_items = []
+
+ def _queue_empty(*args, **kwargs):
+ return len(queue_items) == 0
+
+ def _queue_get(*args, **kwargs):
+ if len(queue_items) == 0:
+ raise Queue.Empty
+ else:
+ return queue_items.pop()
+
+ def _queue_put(item, *args, **kwargs):
+ queue_items.append(item)
+
+ mock_queue = MagicMock()
+ mock_queue.empty.side_effect = _queue_empty
+ mock_queue.get.side_effect = _queue_get
+ mock_queue.put.side_effect = _queue_put
+
+ mock_tqm = MagicMock(TaskQueueManager)
+ mock_tqm._final_q = mock_queue
+ mock_tqm._workers = []
+ strategy_base = StrategyBase(tqm=mock_tqm)
+ strategy_base.cleanup()
+
+ def test_strategy_base_run(self):
+ queue_items = []
+
+ def _queue_empty(*args, **kwargs):
+ return len(queue_items) == 0
+
+ def _queue_get(*args, **kwargs):
+ if len(queue_items) == 0:
+ raise Queue.Empty
+ else:
+ return queue_items.pop()
+
+ def _queue_put(item, *args, **kwargs):
+ queue_items.append(item)
+
+ mock_queue = MagicMock()
+ mock_queue.empty.side_effect = _queue_empty
+ mock_queue.get.side_effect = _queue_get
+ mock_queue.put.side_effect = _queue_put
+
+ mock_tqm = MagicMock(TaskQueueManager)
+ mock_tqm._final_q = mock_queue
+ mock_tqm._stats = MagicMock()
+ mock_tqm.send_callback.return_value = None
+
+ for attr in ('RUN_OK', 'RUN_ERROR', 'RUN_FAILED_HOSTS', 'RUN_UNREACHABLE_HOSTS'):
+ setattr(mock_tqm, attr, getattr(TaskQueueManager, attr))
+
+ mock_iterator = MagicMock()
+ mock_iterator._play = MagicMock()
+ mock_iterator._play.handlers = []
+
+ mock_play_context = MagicMock()
+
+ mock_tqm._failed_hosts = dict()
+ mock_tqm._unreachable_hosts = dict()
+ mock_tqm._workers = []
+ strategy_base = StrategyBase(tqm=mock_tqm)
+
+ mock_host = MagicMock()
+ mock_host.name = 'host1'
+
+ self.assertEqual(strategy_base.run(iterator=mock_iterator, play_context=mock_play_context), mock_tqm.RUN_OK)
+ self.assertEqual(strategy_base.run(iterator=mock_iterator, play_context=mock_play_context, result=TaskQueueManager.RUN_ERROR), mock_tqm.RUN_ERROR)
+ mock_tqm._failed_hosts = dict(host1=True)
+ mock_iterator.get_failed_hosts.return_value = [mock_host]
+ self.assertEqual(strategy_base.run(iterator=mock_iterator, play_context=mock_play_context, result=False), mock_tqm.RUN_FAILED_HOSTS)
+ mock_tqm._unreachable_hosts = dict(host1=True)
+ mock_iterator.get_failed_hosts.return_value = []
+ self.assertEqual(strategy_base.run(iterator=mock_iterator, play_context=mock_play_context, result=False), mock_tqm.RUN_UNREACHABLE_HOSTS)
+ strategy_base.cleanup()
+
+ def test_strategy_base_get_hosts(self):
+ queue_items = []
+
+ def _queue_empty(*args, **kwargs):
+ return len(queue_items) == 0
+
+ def _queue_get(*args, **kwargs):
+ if len(queue_items) == 0:
+ raise Queue.Empty
+ else:
+ return queue_items.pop()
+
+ def _queue_put(item, *args, **kwargs):
+ queue_items.append(item)
+
+ mock_queue = MagicMock()
+ mock_queue.empty.side_effect = _queue_empty
+ mock_queue.get.side_effect = _queue_get
+ mock_queue.put.side_effect = _queue_put
+
+ mock_hosts = []
+ for i in range(0, 5):
+ mock_host = MagicMock()
+ mock_host.name = "host%02d" % (i + 1)
+ mock_host.has_hostkey = True
+ mock_hosts.append(mock_host)
+
+ mock_hosts_names = [h.name for h in mock_hosts]
+
+ mock_inventory = MagicMock()
+ mock_inventory.get_hosts.return_value = mock_hosts
+
+ mock_tqm = MagicMock()
+ mock_tqm._final_q = mock_queue
+ mock_tqm.get_inventory.return_value = mock_inventory
+
+ mock_play = MagicMock()
+ mock_play.hosts = ["host%02d" % (i + 1) for i in range(0, 5)]
+
+ strategy_base = StrategyBase(tqm=mock_tqm)
+ strategy_base._hosts_cache = strategy_base._hosts_cache_all = mock_hosts_names
+
+ mock_tqm._failed_hosts = []
+ mock_tqm._unreachable_hosts = []
+ self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), [h.name for h in mock_hosts])
+
+ mock_tqm._failed_hosts = ["host01"]
+ self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), [h.name for h in mock_hosts[1:]])
+ self.assertEqual(strategy_base.get_failed_hosts(play=mock_play), [mock_hosts[0].name])
+
+ mock_tqm._unreachable_hosts = ["host02"]
+ self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), [h.name for h in mock_hosts[2:]])
+ strategy_base.cleanup()
+
+ @patch.object(WorkerProcess, 'run')
+ def test_strategy_base_queue_task(self, mock_worker):
+ def fake_run(self):
+ return
+
+ mock_worker.run.side_effect = fake_run
+
+ fake_loader = DictDataLoader()
+ mock_var_manager = MagicMock()
+ mock_host = MagicMock()
+ mock_host.get_vars.return_value = dict()
+ mock_host.has_hostkey = True
+ mock_inventory = MagicMock()
+ mock_inventory.get.return_value = mock_host
+
+ tqm = TaskQueueManager(
+ inventory=mock_inventory,
+ variable_manager=mock_var_manager,
+ loader=fake_loader,
+ passwords=None,
+ forks=3,
+ )
+ tqm._initialize_processes(3)
+ tqm.hostvars = dict()
+
+ mock_task = MagicMock()
+ mock_task._uuid = 'abcd'
+ mock_task.throttle = 0
+
+ try:
+ strategy_base = StrategyBase(tqm=tqm)
+ strategy_base._queue_task(host=mock_host, task=mock_task, task_vars=dict(), play_context=MagicMock())
+ self.assertEqual(strategy_base._cur_worker, 1)
+ self.assertEqual(strategy_base._pending_results, 1)
+ strategy_base._queue_task(host=mock_host, task=mock_task, task_vars=dict(), play_context=MagicMock())
+ self.assertEqual(strategy_base._cur_worker, 2)
+ self.assertEqual(strategy_base._pending_results, 2)
+ strategy_base._queue_task(host=mock_host, task=mock_task, task_vars=dict(), play_context=MagicMock())
+ self.assertEqual(strategy_base._cur_worker, 0)
+ self.assertEqual(strategy_base._pending_results, 3)
+ finally:
+ tqm.cleanup()
+
+ def test_strategy_base_process_pending_results(self):
+ mock_tqm = MagicMock()
+ mock_tqm._terminated = False
+ mock_tqm._failed_hosts = dict()
+ mock_tqm._unreachable_hosts = dict()
+ mock_tqm.send_callback.return_value = None
+
+ queue_items = []
+
+ def _queue_empty(*args, **kwargs):
+ return len(queue_items) == 0
+
+ def _queue_get(*args, **kwargs):
+ if len(queue_items) == 0:
+ raise Queue.Empty
+ else:
+ return queue_items.pop()
+
+ def _queue_put(item, *args, **kwargs):
+ queue_items.append(item)
+
+ mock_queue = MagicMock()
+ mock_queue.empty.side_effect = _queue_empty
+ mock_queue.get.side_effect = _queue_get
+ mock_queue.put.side_effect = _queue_put
+ mock_tqm._final_q = mock_queue
+
+ mock_tqm._stats = MagicMock()
+ mock_tqm._stats.increment.return_value = None
+
+ mock_play = MagicMock()
+
+ mock_host = MagicMock()
+ mock_host.name = 'test01'
+ mock_host.vars = dict()
+ mock_host.get_vars.return_value = dict()
+ mock_host.has_hostkey = True
+
+ mock_task = MagicMock()
+ mock_task._role = None
+ mock_task._parent = None
+ mock_task.ignore_errors = False
+ mock_task.ignore_unreachable = False
+ mock_task._uuid = str(uuid.uuid4())
+ mock_task.loop = None
+ mock_task.copy.return_value = mock_task
+
+ mock_handler_task = Handler()
+ mock_handler_task.name = 'test handler'
+ mock_handler_task.action = 'foo'
+ mock_handler_task._parent = None
+ mock_handler_task._uuid = 'xxxxxxxxxxxxx'
+
+ mock_iterator = MagicMock()
+ mock_iterator._play = mock_play
+ mock_iterator.mark_host_failed.return_value = None
+ mock_iterator.get_next_task_for_host.return_value = (None, None)
+
+ mock_handler_block = MagicMock()
+ mock_handler_block.name = '' # implicit unnamed block
+ mock_handler_block.block = [mock_handler_task]
+ mock_handler_block.rescue = []
+ mock_handler_block.always = []
+ mock_play.handlers = [mock_handler_block]
+
+ mock_group = MagicMock()
+ mock_group.add_host.return_value = None
+
+ def _get_host(host_name):
+ if host_name == 'test01':
+ return mock_host
+ return None
+
+ def _get_group(group_name):
+ if group_name in ('all', 'foo'):
+ return mock_group
+ return None
+
+ mock_inventory = MagicMock()
+ mock_inventory._hosts_cache = dict()
+ mock_inventory.hosts.return_value = mock_host
+ mock_inventory.get_host.side_effect = _get_host
+ mock_inventory.get_group.side_effect = _get_group
+ mock_inventory.clear_pattern_cache.return_value = None
+ mock_inventory.get_host_vars.return_value = {}
+ mock_inventory.hosts.get.return_value = mock_host
+
+ mock_var_mgr = MagicMock()
+ mock_var_mgr.set_host_variable.return_value = None
+ mock_var_mgr.set_host_facts.return_value = None
+ mock_var_mgr.get_vars.return_value = dict()
+
+ strategy_base = StrategyBase(tqm=mock_tqm)
+ strategy_base._inventory = mock_inventory
+ strategy_base._variable_manager = mock_var_mgr
+ strategy_base._blocked_hosts = dict()
+
+ def _has_dead_workers():
+ return False
+
+ strategy_base._tqm.has_dead_workers.side_effect = _has_dead_workers
+ results = strategy_base._wait_on_pending_results(iterator=mock_iterator)
+ self.assertEqual(len(results), 0)
+
+ task_result = TaskResult(host=mock_host.name, task=mock_task._uuid, return_data=dict(changed=True))
+ queue_items.append(task_result)
+ strategy_base._blocked_hosts['test01'] = True
+ strategy_base._pending_results = 1
+
+ def mock_queued_task_cache():
+ return {
+ (mock_host.name, mock_task._uuid): {
+ 'task': mock_task,
+ 'host': mock_host,
+ 'task_vars': {},
+ 'play_context': {},
+ }
+ }
+
+ strategy_base._queued_task_cache = mock_queued_task_cache()
+ results = strategy_base._wait_on_pending_results(iterator=mock_iterator)
+ self.assertEqual(len(results), 1)
+ self.assertEqual(results[0], task_result)
+ self.assertEqual(strategy_base._pending_results, 0)
+ self.assertNotIn('test01', strategy_base._blocked_hosts)
+
+ task_result = TaskResult(host=mock_host.name, task=mock_task._uuid, return_data='{"failed":true}')
+ queue_items.append(task_result)
+ strategy_base._blocked_hosts['test01'] = True
+ strategy_base._pending_results = 1
+ mock_iterator.is_failed.return_value = True
+ strategy_base._queued_task_cache = mock_queued_task_cache()
+ results = strategy_base._wait_on_pending_results(iterator=mock_iterator)
+ self.assertEqual(len(results), 1)
+ self.assertEqual(results[0], task_result)
+ self.assertEqual(strategy_base._pending_results, 0)
+ self.assertNotIn('test01', strategy_base._blocked_hosts)
+ # self.assertIn('test01', mock_tqm._failed_hosts)
+ # del mock_tqm._failed_hosts['test01']
+ mock_iterator.is_failed.return_value = False
+
+ task_result = TaskResult(host=mock_host.name, task=mock_task._uuid, return_data='{"unreachable": true}')
+ queue_items.append(task_result)
+ strategy_base._blocked_hosts['test01'] = True
+ strategy_base._pending_results = 1
+ strategy_base._queued_task_cache = mock_queued_task_cache()
+ results = strategy_base._wait_on_pending_results(iterator=mock_iterator)
+ self.assertEqual(len(results), 1)
+ self.assertEqual(results[0], task_result)
+ self.assertEqual(strategy_base._pending_results, 0)
+ self.assertNotIn('test01', strategy_base._blocked_hosts)
+ self.assertIn('test01', mock_tqm._unreachable_hosts)
+ del mock_tqm._unreachable_hosts['test01']
+
+ task_result = TaskResult(host=mock_host.name, task=mock_task._uuid, return_data='{"skipped": true}')
+ queue_items.append(task_result)
+ strategy_base._blocked_hosts['test01'] = True
+ strategy_base._pending_results = 1
+ strategy_base._queued_task_cache = mock_queued_task_cache()
+ results = strategy_base._wait_on_pending_results(iterator=mock_iterator)
+ self.assertEqual(len(results), 1)
+ self.assertEqual(results[0], task_result)
+ self.assertEqual(strategy_base._pending_results, 0)
+ self.assertNotIn('test01', strategy_base._blocked_hosts)
+
+ queue_items.append(TaskResult(host=mock_host.name, task=mock_task._uuid, return_data=dict(add_host=dict(host_name='newhost01', new_groups=['foo']))))
+ strategy_base._blocked_hosts['test01'] = True
+ strategy_base._pending_results = 1
+ strategy_base._queued_task_cache = mock_queued_task_cache()
+ results = strategy_base._wait_on_pending_results(iterator=mock_iterator)
+ self.assertEqual(len(results), 1)
+ self.assertEqual(strategy_base._pending_results, 0)
+ self.assertNotIn('test01', strategy_base._blocked_hosts)
+
+ queue_items.append(TaskResult(host=mock_host.name, task=mock_task._uuid, return_data=dict(add_group=dict(group_name='foo'))))
+ strategy_base._blocked_hosts['test01'] = True
+ strategy_base._pending_results = 1
+ strategy_base._queued_task_cache = mock_queued_task_cache()
+ results = strategy_base._wait_on_pending_results(iterator=mock_iterator)
+ self.assertEqual(len(results), 1)
+ self.assertEqual(strategy_base._pending_results, 0)
+ self.assertNotIn('test01', strategy_base._blocked_hosts)
+
+ queue_items.append(TaskResult(host=mock_host.name, task=mock_task._uuid, return_data=dict(changed=True, _ansible_notify=['test handler'])))
+ strategy_base._blocked_hosts['test01'] = True
+ strategy_base._pending_results = 1
+ strategy_base._queued_task_cache = mock_queued_task_cache()
+ results = strategy_base._wait_on_pending_results(iterator=mock_iterator)
+ self.assertEqual(len(results), 1)
+ self.assertEqual(strategy_base._pending_results, 0)
+ self.assertNotIn('test01', strategy_base._blocked_hosts)
+ self.assertEqual(mock_iterator._play.handlers[0].block[0], mock_handler_task)
+
+ # queue_items.append(('set_host_var', mock_host, mock_task, None, 'foo', 'bar'))
+ # results = strategy_base._process_pending_results(iterator=mock_iterator)
+ # self.assertEqual(len(results), 0)
+ # self.assertEqual(strategy_base._pending_results, 1)
+
+ # queue_items.append(('set_host_facts', mock_host, mock_task, None, 'foo', dict()))
+ # results = strategy_base._process_pending_results(iterator=mock_iterator)
+ # self.assertEqual(len(results), 0)
+ # self.assertEqual(strategy_base._pending_results, 1)
+
+ # queue_items.append(('bad'))
+ # self.assertRaises(AnsibleError, strategy_base._process_pending_results, iterator=mock_iterator)
+ strategy_base.cleanup()
+
+ def test_strategy_base_load_included_file(self):
+ fake_loader = DictDataLoader({
+ "test.yml": """
+ - debug: msg='foo'
+ """,
+ "bad.yml": """
+ """,
+ })
+
+ queue_items = []
+
+ def _queue_empty(*args, **kwargs):
+ return len(queue_items) == 0
+
+ def _queue_get(*args, **kwargs):
+ if len(queue_items) == 0:
+ raise Queue.Empty
+ else:
+ return queue_items.pop()
+
+ def _queue_put(item, *args, **kwargs):
+ queue_items.append(item)
+
+ mock_queue = MagicMock()
+ mock_queue.empty.side_effect = _queue_empty
+ mock_queue.get.side_effect = _queue_get
+ mock_queue.put.side_effect = _queue_put
+
+ mock_tqm = MagicMock()
+ mock_tqm._final_q = mock_queue
+
+ strategy_base = StrategyBase(tqm=mock_tqm)
+ strategy_base._loader = fake_loader
+ strategy_base.cleanup()
+
+ mock_play = MagicMock()
+
+ mock_block = MagicMock()
+ mock_block._play = mock_play
+ mock_block.vars = dict()
+
+ mock_task = MagicMock()
+ mock_task._block = mock_block
+ mock_task._role = None
+
+ # NOTE Mocking calls below to account for passing parent_block=ti_copy.build_parent_block()
+ # into load_list_of_blocks() in _load_included_file. Not doing so meant that retrieving
+ # `collection` attr from parent would result in getting MagicMock instance
+ # instead of an empty list.
+ mock_task._parent = MagicMock()
+ mock_task.copy.return_value = mock_task
+ mock_task.build_parent_block.return_value = mock_block
+ mock_block._get_parent_attribute.return_value = None
+
+ mock_iterator = MagicMock()
+ mock_iterator.mark_host_failed.return_value = None
+
+ mock_inc_file = MagicMock()
+ mock_inc_file._task = mock_task
+
+ mock_inc_file._filename = "test.yml"
+ res = strategy_base._load_included_file(included_file=mock_inc_file, iterator=mock_iterator)
+ self.assertEqual(len(res), 1)
+ self.assertTrue(isinstance(res[0], Block))
+
+ mock_inc_file._filename = "bad.yml"
+ res = strategy_base._load_included_file(included_file=mock_inc_file, iterator=mock_iterator)
+ self.assertEqual(res, [])
diff --git a/test/units/plugins/test_plugins.py b/test/units/plugins/test_plugins.py
new file mode 100644
index 0000000..be123b1
--- /dev/null
+++ b/test/units/plugins/test_plugins.py
@@ -0,0 +1,133 @@
+# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+
+# Make coding more python3-ish
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import os
+
+from units.compat import unittest
+from unittest.mock import patch, MagicMock
+from ansible.plugins.loader import PluginLoader, PluginPathContext
+
+
+class TestErrors(unittest.TestCase):
+
+ @patch.object(PluginLoader, '_get_paths')
+ def test_print_paths(self, mock_method):
+ mock_method.return_value = ['/path/one', '/path/two', '/path/three']
+ pl = PluginLoader('foo', 'foo', '', 'test_plugins')
+ paths = pl.print_paths()
+ expected_paths = os.pathsep.join(['/path/one', '/path/two', '/path/three'])
+ self.assertEqual(paths, expected_paths)
+
+ def test_plugins__get_package_paths_no_package(self):
+ pl = PluginLoader('test', '', 'test', 'test_plugin')
+ self.assertEqual(pl._get_package_paths(), [])
+
+ def test_plugins__get_package_paths_with_package(self):
+ # the _get_package_paths() call uses __import__ to load a
+ # python library, and then uses the __file__ attribute of
+ # the result for that to get the library path, so we mock
+ # that here and patch the builtin to use our mocked result
+ foo = MagicMock()
+ bar = MagicMock()
+ bam = MagicMock()
+ bam.__file__ = '/path/to/my/foo/bar/bam/__init__.py'
+ bar.bam = bam
+ foo.return_value.bar = bar
+ pl = PluginLoader('test', 'foo.bar.bam', 'test', 'test_plugin')
+ with patch('builtins.__import__', foo):
+ self.assertEqual(pl._get_package_paths(), ['/path/to/my/foo/bar/bam'])
+
+ def test_plugins__get_paths(self):
+ pl = PluginLoader('test', '', 'test', 'test_plugin')
+ pl._paths = [PluginPathContext('/path/one', False),
+ PluginPathContext('/path/two', True)]
+ self.assertEqual(pl._get_paths(), ['/path/one', '/path/two'])
+
+ # NOT YET WORKING
+ # def fake_glob(path):
+ # if path == 'test/*':
+ # return ['test/foo', 'test/bar', 'test/bam']
+ # elif path == 'test/*/*'
+ # m._paths = None
+ # mock_glob = MagicMock()
+ # mock_glob.return_value = []
+ # with patch('glob.glob', mock_glob):
+ # pass
+
+ def assertPluginLoaderConfigBecomes(self, arg, expected):
+ pl = PluginLoader('test', '', arg, 'test_plugin')
+ self.assertEqual(pl.config, expected)
+
+ def test_plugin__init_config_list(self):
+ config = ['/one', '/two']
+ self.assertPluginLoaderConfigBecomes(config, config)
+
+ def test_plugin__init_config_str(self):
+ self.assertPluginLoaderConfigBecomes('test', ['test'])
+
+ def test_plugin__init_config_none(self):
+ self.assertPluginLoaderConfigBecomes(None, [])
+
+ def test__load_module_source_no_duplicate_names(self):
+ '''
+ This test simulates importing 2 plugins with the same name,
+ and validating that the import is short circuited if a file with the same name
+ has already been imported
+ '''
+
+ fixture_path = os.path.join(os.path.dirname(__file__), 'loader_fixtures')
+
+ pl = PluginLoader('test', '', 'test', 'test_plugin')
+ one = pl._load_module_source('import_fixture', os.path.join(fixture_path, 'import_fixture.py'))
+ # This line wouldn't even succeed if we didn't short circuit on finding a duplicate name
+ two = pl._load_module_source('import_fixture', '/path/to/import_fixture.py')
+
+ self.assertEqual(one, two)
+
+ @patch('ansible.plugins.loader.glob')
+ @patch.object(PluginLoader, '_get_paths_with_context')
+ def test_all_no_duplicate_names(self, gp_mock, glob_mock):
+ '''
+ This test goes along with ``test__load_module_source_no_duplicate_names``
+ and ensures that we ignore duplicate imports on multiple paths
+ '''
+
+ fixture_path = os.path.join(os.path.dirname(__file__), 'loader_fixtures')
+
+ gp_mock.return_value = [
+ MagicMock(path=fixture_path),
+ MagicMock(path='/path/to'),
+ ]
+
+ glob_mock.glob.side_effect = [
+ [os.path.join(fixture_path, 'import_fixture.py')],
+ ['/path/to/import_fixture.py']
+ ]
+
+ pl = PluginLoader('test', '', 'test', 'test_plugins')
+ # Aside from needing ``list()`` so we can do a len, ``PluginLoader.all`` returns a generator
+ # so ``list()`` actually causes ``PluginLoader.all`` to run.
+ plugins = list(pl.all())
+ self.assertEqual(len(plugins), 1)
+
+ self.assertIn(os.path.join(fixture_path, 'import_fixture.py'), pl._module_cache)
+ self.assertNotIn('/path/to/import_fixture.py', pl._module_cache)