diff options
Diffstat (limited to '')
-rw-r--r-- | testing/mozharness/test/test_base_script.py | 983 |
1 files changed, 983 insertions, 0 deletions
diff --git a/testing/mozharness/test/test_base_script.py b/testing/mozharness/test/test_base_script.py new file mode 100644 index 0000000000..8acb630053 --- /dev/null +++ b/testing/mozharness/test/test_base_script.py @@ -0,0 +1,983 @@ +import gc +import os +import re +import shutil +import tempfile +import types +import unittest +from unittest import mock + +PYWIN32 = False +if os.name == "nt": + try: + import win32file + + PYWIN32 = True + except ImportError: + pass + + +import mozharness.base.errors as errors +import mozharness.base.log as log +import mozharness.base.script as script +from mozharness.base.config import parse_config_file +from mozharness.base.log import CRITICAL, DEBUG, ERROR, FATAL, IGNORE, INFO, WARNING + +here = os.path.dirname(os.path.abspath(__file__)) + +test_string = """foo +bar +baz""" + + +class CleanupObj(script.ScriptMixin, log.LogMixin): + def __init__(self): + super(CleanupObj, self).__init__() + self.log_obj = None + self.config = {"log_level": ERROR} + + +def cleanup(files=None): + files = files or [] + files.extend(("test_logs", "test_dir", "tmpfile_stdout", "tmpfile_stderr")) + gc.collect() + c = CleanupObj() + for f in files: + c.rmtree(f) + + +def get_debug_script_obj(): + s = script.BaseScript( + config={"log_type": "multi", "log_level": DEBUG}, + initial_config_file="test/test.json", + ) + return s + + +def _post_fatal(self, **kwargs): + fh = open("tmpfile_stdout", "w") + print(test_string, file=fh) + fh.close() + + +# TestScript {{{1 +class TestScript(unittest.TestCase): + def setUp(self): + cleanup() + self.s = None + self.tmpdir = tempfile.mkdtemp(suffix=".mozharness") + + def tearDown(self): + # Close the logfile handles, or windows can't remove the logs + if hasattr(self, "s") and isinstance(self.s, object): + del self.s + cleanup([self.tmpdir]) + + # test _dump_config_hierarchy() when --dump-config-hierarchy is passed + def test_dump_config_hierarchy_valid_files_len(self): + try: + self.s = script.BaseScript( + initial_config_file="test/test.json", + option_args=["--cfg", "test/test_override.py,test/test_override2.py"], + config={"dump_config_hierarchy": True}, + ) + except SystemExit: + local_cfg_files = parse_config_file("test_logs/localconfigfiles.json") + # first let's see if the correct number of config files were + # realized + self.assertEqual( + len(local_cfg_files), + 4, + msg="--dump-config-hierarchy dumped wrong number of config files", + ) + + def test_dump_config_hierarchy_keys_unique_and_valid(self): + try: + self.s = script.BaseScript( + initial_config_file="test/test.json", + option_args=["--cfg", "test/test_override.py,test/test_override2.py"], + config={"dump_config_hierarchy": True}, + ) + except SystemExit: + local_cfg_files = parse_config_file("test_logs/localconfigfiles.json") + # now let's see if only unique items were added from each config + t_override = local_cfg_files.get("test/test_override.py", {}) + self.assertTrue( + t_override.get("keep_string") == "don't change me" + and len(t_override.keys()) == 1, + msg="--dump-config-hierarchy dumped wrong keys/value for " + "`test/test_override.py`. There should only be one " + "item and it should be unique to all the other " + "items in test_log/localconfigfiles.json.", + ) + + def test_dump_config_hierarchy_matches_self_config(self): + try: + ###### + # we need temp_cfg because self.s will be gcollected (NoneType) by + # the time we get to SystemExit exception + # temp_cfg will differ from self.s.config because of + # 'dump_config_hierarchy'. we have to make a deepcopy because + # config is a locked dict + temp_s = script.BaseScript( + initial_config_file="test/test.json", + option_args=["--cfg", "test/test_override.py,test/test_override2.py"], + ) + from copy import deepcopy + + temp_cfg = deepcopy(temp_s.config) + temp_cfg.update({"dump_config_hierarchy": True}) + ###### + self.s = script.BaseScript( + initial_config_file="test/test.json", + option_args=["--cfg", "test/test_override.py,test/test_override2.py"], + config={"dump_config_hierarchy": True}, + ) + except SystemExit: + local_cfg_files = parse_config_file("test_logs/localconfigfiles.json") + # finally let's just make sure that all the items added up, equals + # what we started with: self.config + target_cfg = {} + for cfg_file in local_cfg_files: + target_cfg.update(local_cfg_files[cfg_file]) + self.assertEqual( + target_cfg, + temp_cfg, + msg="all of the items (combined) in each cfg file dumped via " + "--dump-config-hierarchy does not equal self.config ", + ) + + # test _dump_config() when --dump-config is passed + def test_dump_config_equals_self_config(self): + try: + ###### + # we need temp_cfg because self.s will be gcollected (NoneType) by + # the time we get to SystemExit exception + # temp_cfg will differ from self.s.config because of + # 'dump_config_hierarchy'. we have to make a deepcopy because + # config is a locked dict + temp_s = script.BaseScript( + initial_config_file="test/test.json", + option_args=["--cfg", "test/test_override.py,test/test_override2.py"], + ) + from copy import deepcopy + + temp_cfg = deepcopy(temp_s.config) + temp_cfg.update({"dump_config": True}) + ###### + self.s = script.BaseScript( + initial_config_file="test/test.json", + option_args=["--cfg", "test/test_override.py,test/test_override2.py"], + config={"dump_config": True}, + ) + except SystemExit: + target_cfg = parse_config_file("test_logs/localconfig.json") + self.assertEqual( + target_cfg, + temp_cfg, + msg="all of the items (combined) in each cfg file dumped via " + "--dump-config does not equal self.config ", + ) + + def test_nonexistent_mkdir_p(self): + self.s = script.BaseScript(initial_config_file="test/test.json") + self.s.mkdir_p("test_dir/foo/bar/baz") + self.assertTrue(os.path.isdir("test_dir/foo/bar/baz"), msg="mkdir_p error") + + def test_existing_mkdir_p(self): + self.s = script.BaseScript(initial_config_file="test/test.json") + os.makedirs("test_dir/foo/bar/baz") + self.s.mkdir_p("test_dir/foo/bar/baz") + self.assertTrue( + os.path.isdir("test_dir/foo/bar/baz"), msg="mkdir_p error when dir exists" + ) + + def test_chdir(self): + self.s = script.BaseScript(initial_config_file="test/test.json") + cwd = os.getcwd() + self.s.chdir("test_logs") + self.assertEqual(os.path.join(cwd, "test_logs"), os.getcwd(), msg="chdir error") + self.s.chdir(cwd) + + def _test_log_helper(self, obj): + obj.debug("Testing DEBUG") + obj.warning("Testing WARNING") + obj.error("Testing ERROR") + obj.critical("Testing CRITICAL") + try: + obj.fatal("Testing FATAL") + except SystemExit: + pass + else: + self.assertTrue(False, msg="fatal() didn't SystemExit!") + + def test_log(self): + self.s = get_debug_script_obj() + self.s.log_obj = None + self._test_log_helper(self.s) + del self.s + self.s = script.BaseScript(initial_config_file="test/test.json") + self._test_log_helper(self.s) + + def test_run_nonexistent_command(self): + self.s = get_debug_script_obj() + self.s.run_command( + command="this_cmd_should_not_exist --help", + env={"GARBLE": "FARG"}, + error_list=errors.PythonErrorList, + ) + error_logsize = os.path.getsize("test_logs/test_info.log") + self.assertTrue(error_logsize > 0, msg="command not found error not hit") + + def test_run_command_in_bad_dir(self): + self.s = get_debug_script_obj() + self.s.run_command( + command="ls", + cwd="/this_dir_should_not_exist", + error_list=errors.PythonErrorList, + ) + error_logsize = os.path.getsize("test_logs/test_error.log") + self.assertTrue(error_logsize > 0, msg="bad dir error not hit") + + def test_get_output_from_command_in_bad_dir(self): + self.s = get_debug_script_obj() + self.s.get_output_from_command(command="ls", cwd="/this_dir_should_not_exist") + error_logsize = os.path.getsize("test_logs/test_error.log") + self.assertTrue(error_logsize > 0, msg="bad dir error not hit") + + def test_get_output_from_command_with_missing_file(self): + self.s = get_debug_script_obj() + self.s.get_output_from_command(command="ls /this_file_should_not_exist") + error_logsize = os.path.getsize("test_logs/test_error.log") + self.assertTrue(error_logsize > 0, msg="bad file error not hit") + + def test_get_output_from_command_with_missing_file2(self): + self.s = get_debug_script_obj() + self.s.run_command( + command="cat mozharness/base/errors.py", + error_list=[ + {"substr": "error", "level": ERROR}, + { + "regex": re.compile(",$"), + "level": IGNORE, + }, + { + "substr": "]$", + "level": WARNING, + }, + ], + ) + error_logsize = os.path.getsize("test_logs/test_error.log") + self.assertTrue(error_logsize > 0, msg="error list not working properly") + + def test_download_unpack(self): + # NOTE: The action is called *download*, however, it can work for files in disk + self.s = get_debug_script_obj() + + archives_path = os.path.join(here, "helper_files", "archives") + + # Test basic decompression + for archive in ( + "archive.tar", + "archive.tar.bz2", + "archive.tar.gz", + "archive.zip", + ): + self.s.download_unpack( + url=os.path.join(archives_path, archive), extract_to=self.tmpdir + ) + self.assertIn("script.sh", os.listdir(os.path.join(self.tmpdir, "bin"))) + self.assertIn("lorem.txt", os.listdir(self.tmpdir)) + shutil.rmtree(self.tmpdir) + + # Test permissions for extracted entries from zip archive + self.s.download_unpack( + url=os.path.join(archives_path, "archive.zip"), + extract_to=self.tmpdir, + ) + file_stats = os.stat(os.path.join(self.tmpdir, "bin", "script.sh")) + orig_fstats = os.stat( + os.path.join(archives_path, "reference", "bin", "script.sh") + ) + self.assertEqual(file_stats.st_mode, orig_fstats.st_mode) + shutil.rmtree(self.tmpdir) + + # Test unzip specific dirs only + self.s.download_unpack( + url=os.path.join(archives_path, "archive.zip"), + extract_to=self.tmpdir, + extract_dirs=["bin/*"], + ) + self.assertIn("bin", os.listdir(self.tmpdir)) + self.assertNotIn("lorem.txt", os.listdir(self.tmpdir)) + shutil.rmtree(self.tmpdir) + + # Test for invalid filenames (Windows only) + if PYWIN32: + with self.assertRaises(IOError): + self.s.download_unpack( + url=os.path.join(archives_path, "archive_invalid_filename.zip"), + extract_to=self.tmpdir, + ) + + for archive in ( + "archive-setuid.tar", + "archive-escape.tar", + "archive-link.tar", + "archive-link-abs.tar", + "archive-double-link.tar", + ): + with self.assertRaises(Exception): + self.s.download_unpack( + url=os.path.join(archives_path, archive), + extract_to=self.tmpdir, + ) + + def test_unpack(self): + self.s = get_debug_script_obj() + + archives_path = os.path.join(here, "helper_files", "archives") + + # Test basic decompression + for archive in ( + "archive.tar", + "archive.tar.bz2", + "archive.tar.gz", + "archive.zip", + ): + self.s.unpack(os.path.join(archives_path, archive), self.tmpdir) + self.assertIn("script.sh", os.listdir(os.path.join(self.tmpdir, "bin"))) + self.assertIn("lorem.txt", os.listdir(self.tmpdir)) + shutil.rmtree(self.tmpdir) + + # Test permissions for extracted entries from zip archive + self.s.unpack(os.path.join(archives_path, "archive.zip"), self.tmpdir) + file_stats = os.stat(os.path.join(self.tmpdir, "bin", "script.sh")) + orig_fstats = os.stat( + os.path.join(archives_path, "reference", "bin", "script.sh") + ) + self.assertEqual(file_stats.st_mode, orig_fstats.st_mode) + shutil.rmtree(self.tmpdir) + + # Test extract specific dirs only + self.s.unpack( + os.path.join(archives_path, "archive.zip"), + self.tmpdir, + extract_dirs=["bin/*"], + ) + self.assertIn("bin", os.listdir(self.tmpdir)) + self.assertNotIn("lorem.txt", os.listdir(self.tmpdir)) + shutil.rmtree(self.tmpdir) + + # Test for invalid filenames (Windows only) + if PYWIN32: + with self.assertRaises(IOError): + self.s.unpack( + os.path.join(archives_path, "archive_invalid_filename.zip"), + self.tmpdir, + ) + + for archive in ( + "archive-setuid.tar", + "archive-escape.tar", + "archive-link.tar", + "archive-link-abs.tar", + "archive-double-link.tar", + ): + with self.assertRaises(Exception): + self.s.unpack(os.path.join(archives_path, archive), self.tmpdir) + + +# TestHelperFunctions {{{1 +class TestHelperFunctions(unittest.TestCase): + temp_file = "test_dir/mozilla" + + def setUp(self): + cleanup() + self.s = None + + def tearDown(self): + # Close the logfile handles, or windows can't remove the logs + if hasattr(self, "s") and isinstance(self.s, object): + del self.s + cleanup() + + def _create_temp_file(self, contents=test_string): + os.mkdir("test_dir") + fh = open(self.temp_file, "w+") + fh.write(contents) + fh.close + + def test_mkdir_p(self): + self.s = script.BaseScript(initial_config_file="test/test.json") + self.s.mkdir_p("test_dir") + self.assertTrue(os.path.isdir("test_dir"), msg="mkdir_p error") + + def test_get_output_from_command(self): + self._create_temp_file() + self.s = script.BaseScript(initial_config_file="test/test.json") + contents = self.s.get_output_from_command( + ["bash", "-c", "cat %s" % self.temp_file] + ) + self.assertEqual( + test_string, + contents, + msg="get_output_from_command('cat file') differs from fh.write", + ) + + def test_run_command(self): + self._create_temp_file() + self.s = script.BaseScript(initial_config_file="test/test.json") + temp_file_name = os.path.basename(self.temp_file) + self.assertEqual( + self.s.run_command("cat %s" % temp_file_name, cwd="test_dir"), + 0, + msg="run_command('cat file') did not exit 0", + ) + + def test_move1(self): + self._create_temp_file() + self.s = script.BaseScript(initial_config_file="test/test.json") + temp_file2 = "%s2" % self.temp_file + self.s.move(self.temp_file, temp_file2) + self.assertFalse( + os.path.exists(self.temp_file), + msg="%s still exists after move()" % self.temp_file, + ) + + def test_move2(self): + self._create_temp_file() + self.s = script.BaseScript(initial_config_file="test/test.json") + temp_file2 = "%s2" % self.temp_file + self.s.move(self.temp_file, temp_file2) + self.assertTrue( + os.path.exists(temp_file2), msg="%s doesn't exist after move()" % temp_file2 + ) + + def test_copyfile(self): + self._create_temp_file() + self.s = script.BaseScript(initial_config_file="test/test.json") + temp_file2 = "%s2" % self.temp_file + self.s.copyfile(self.temp_file, temp_file2) + self.assertEqual( + os.path.getsize(self.temp_file), + os.path.getsize(temp_file2), + msg="%s and %s are different sizes after copyfile()" + % (self.temp_file, temp_file2), + ) + + def test_existing_rmtree(self): + self._create_temp_file() + self.s = script.BaseScript(initial_config_file="test/test.json") + self.s.mkdir_p("test_dir/foo/bar/baz") + self.s.rmtree("test_dir") + self.assertFalse(os.path.exists("test_dir"), msg="rmtree unsuccessful") + + def test_nonexistent_rmtree(self): + self.s = script.BaseScript(initial_config_file="test/test.json") + status = self.s.rmtree("test_dir") + self.assertFalse(status, msg="nonexistent rmtree error") + + @unittest.skipUnless(PYWIN32, "PyWin32 specific") + def test_long_dir_rmtree(self): + self.s = script.BaseScript(initial_config_file="test/test.json") + # create a very long path that the command-prompt cannot delete + # by using unicode format (max path length 32000) + path = u"\\\\?\\%s\\test_dir" % os.getcwd() + win32file.CreateDirectoryExW(u".", path) + + for x in range(0, 20): + print("path=%s" % path) + path = path + u"\\%sxxxxxxxxxxxxxxxxxxxx" % x + win32file.CreateDirectoryExW(u".", path) + self.s.rmtree("test_dir") + self.assertFalse(os.path.exists("test_dir"), msg="rmtree unsuccessful") + + @unittest.skipUnless(PYWIN32, "PyWin32 specific") + def test_chmod_rmtree(self): + self._create_temp_file() + win32file.SetFileAttributesW(self.temp_file, win32file.FILE_ATTRIBUTE_READONLY) + self.s = script.BaseScript(initial_config_file="test/test.json") + self.s.rmtree("test_dir") + self.assertFalse(os.path.exists("test_dir"), msg="rmtree unsuccessful") + + @unittest.skipIf(os.name == "nt", "Not for Windows") + def test_chmod(self): + self._create_temp_file() + self.s = script.BaseScript(initial_config_file="test/test.json") + self.s.chmod(self.temp_file, 0o100700) + self.assertEqual(os.stat(self.temp_file)[0], 33216, msg="chmod unsuccessful") + + def test_env_normal(self): + self.s = script.BaseScript(initial_config_file="test/test.json") + script_env = self.s.query_env() + self.assertEqual( + script_env, + os.environ, + msg="query_env() != env\n%s\n%s" % (script_env, os.environ), + ) + + def test_env_normal2(self): + self.s = script.BaseScript(initial_config_file="test/test.json") + self.s.query_env() + script_env = self.s.query_env() + self.assertEqual( + script_env, + os.environ, + msg="Second query_env() != env\n%s\n%s" % (script_env, os.environ), + ) + + def test_env_partial(self): + self.s = script.BaseScript(initial_config_file="test/test.json") + script_env = self.s.query_env(partial_env={"foo": "bar"}) + self.assertTrue("foo" in script_env and script_env["foo"] == "bar") + + def test_env_path(self): + self.s = script.BaseScript(initial_config_file="test/test.json") + partial_path = "yaddayadda:%(PATH)s" + full_path = partial_path % {"PATH": os.environ["PATH"]} + script_env = self.s.query_env(partial_env={"PATH": partial_path}) + self.assertEqual(script_env["PATH"], full_path) + + def test_query_exe(self): + self.s = script.BaseScript( + initial_config_file="test/test.json", + config={"exes": {"foo": "bar"}}, + ) + path = self.s.query_exe("foo") + self.assertEqual(path, "bar") + + def test_query_exe_string_replacement(self): + self.s = script.BaseScript( + initial_config_file="test/test.json", + config={ + "base_work_dir": "foo", + "work_dir": "bar", + "exes": {"foo": os.path.join("%(abs_work_dir)s", "baz")}, + }, + ) + path = self.s.query_exe("foo") + self.assertEqual(path, os.path.join("foo", "bar", "baz")) + + def test_read_from_file(self): + self._create_temp_file() + self.s = script.BaseScript(initial_config_file="test/test.json") + contents = self.s.read_from_file(self.temp_file) + self.assertEqual(contents, test_string) + + def test_read_from_nonexistent_file(self): + self.s = script.BaseScript(initial_config_file="test/test.json") + contents = self.s.read_from_file("nonexistent_file!!!") + self.assertEqual(contents, None) + + +# TestScriptLogging {{{1 +class TestScriptLogging(unittest.TestCase): + # I need a log watcher helper function, here and in test_log. + def setUp(self): + cleanup() + self.s = None + + def tearDown(self): + # Close the logfile handles, or windows can't remove the logs + if hasattr(self, "s") and isinstance(self.s, object): + del self.s + cleanup() + + def test_info_logsize(self): + self.s = script.BaseScript( + config={"log_type": "multi"}, initial_config_file="test/test.json" + ) + info_logsize = os.path.getsize("test_logs/test_info.log") + self.assertTrue(info_logsize > 0, msg="initial info logfile missing/size 0") + + def test_add_summary_info(self): + self.s = script.BaseScript( + config={"log_type": "multi"}, initial_config_file="test/test.json" + ) + info_logsize = os.path.getsize("test_logs/test_info.log") + self.s.add_summary("one") + info_logsize2 = os.path.getsize("test_logs/test_info.log") + self.assertTrue( + info_logsize < info_logsize2, msg="add_summary() info not logged" + ) + + def test_add_summary_warning(self): + self.s = script.BaseScript( + config={"log_type": "multi"}, initial_config_file="test/test.json" + ) + warning_logsize = os.path.getsize("test_logs/test_warning.log") + self.s.add_summary("two", level=WARNING) + warning_logsize2 = os.path.getsize("test_logs/test_warning.log") + self.assertTrue( + warning_logsize < warning_logsize2, + msg="add_summary(level=%s) not logged in warning log" % WARNING, + ) + + def test_summary(self): + self.s = script.BaseScript( + config={"log_type": "multi"}, initial_config_file="test/test.json" + ) + self.s.add_summary("one") + self.s.add_summary("two", level=WARNING) + info_logsize = os.path.getsize("test_logs/test_info.log") + warning_logsize = os.path.getsize("test_logs/test_warning.log") + self.s.summary() + info_logsize2 = os.path.getsize("test_logs/test_info.log") + warning_logsize2 = os.path.getsize("test_logs/test_warning.log") + msg = "" + if info_logsize >= info_logsize2: + msg += "summary() didn't log to info!\n" + if warning_logsize >= warning_logsize2: + msg += "summary() didn't log to warning!\n" + self.assertEqual(msg, "", msg=msg) + + def _test_log_level(self, log_level, log_level_file_list): + self.s = script.BaseScript( + config={"log_type": "multi"}, initial_config_file="test/test.json" + ) + if log_level != FATAL: + self.s.log("testing", level=log_level) + else: + self.s._post_fatal = types.MethodType(_post_fatal, self.s) + try: + self.s.fatal("testing") + except SystemExit: + contents = None + if os.path.exists("tmpfile_stdout"): + fh = open("tmpfile_stdout") + contents = fh.read() + fh.close() + self.assertEqual(contents.rstrip(), test_string, "_post_fatal failed!") + del self.s + msg = "" + for level in log_level_file_list: + log_path = "test_logs/test_%s.log" % level + if not os.path.exists(log_path): + msg += "%s doesn't exist!\n" % log_path + else: + filesize = os.path.getsize(log_path) + if not filesize > 0: + msg += "%s is size 0!\n" % log_path + self.assertEqual(msg, "", msg=msg) + + def test_debug(self): + self._test_log_level(DEBUG, []) + + def test_ignore(self): + self._test_log_level(IGNORE, []) + + def test_info(self): + self._test_log_level(INFO, [INFO]) + + def test_warning(self): + self._test_log_level(WARNING, [INFO, WARNING]) + + def test_error(self): + self._test_log_level(ERROR, [INFO, WARNING, ERROR]) + + def test_critical(self): + self._test_log_level(CRITICAL, [INFO, WARNING, ERROR, CRITICAL]) + + def test_fatal(self): + self._test_log_level(FATAL, [INFO, WARNING, ERROR, CRITICAL, FATAL]) + + +# TestRetry {{{1 +class NewError(Exception): + pass + + +class OtherError(Exception): + pass + + +class TestRetry(unittest.TestCase): + def setUp(self): + self.ATTEMPT_N = 1 + self.s = script.BaseScript(initial_config_file="test/test.json") + + def tearDown(self): + # Close the logfile handles, or windows can't remove the logs + if hasattr(self, "s") and isinstance(self.s, object): + del self.s + cleanup() + + def _succeedOnSecondAttempt(self, foo=None, exception=Exception): + if self.ATTEMPT_N == 2: + self.ATTEMPT_N += 1 + return + self.ATTEMPT_N += 1 + raise exception("Fail") + + def _raiseCustomException(self): + return self._succeedOnSecondAttempt(exception=NewError) + + def _alwaysPass(self): + self.ATTEMPT_N += 1 + return True + + def _mirrorArgs(self, *args, **kwargs): + return args, kwargs + + def _alwaysFail(self): + raise Exception("Fail") + + def testRetrySucceed(self): + # Will raise if anything goes wrong + self.s.retry(self._succeedOnSecondAttempt, attempts=2, sleeptime=0) + + def testRetryFailWithoutCatching(self): + self.assertRaises( + Exception, self.s.retry, self._alwaysFail, sleeptime=0, exceptions=() + ) + + def testRetryFailEnsureRaisesLastException(self): + self.assertRaises( + SystemExit, self.s.retry, self._alwaysFail, sleeptime=0, error_level=FATAL + ) + + def testRetrySelectiveExceptionSucceed(self): + self.s.retry( + self._raiseCustomException, + attempts=2, + sleeptime=0, + retry_exceptions=(NewError,), + ) + + def testRetrySelectiveExceptionFail(self): + self.assertRaises( + NewError, + self.s.retry, + self._raiseCustomException, + attempts=2, + sleeptime=0, + retry_exceptions=(OtherError,), + ) + + # TODO: figure out a way to test that the sleep actually happened + def testRetryWithSleep(self): + self.s.retry(self._succeedOnSecondAttempt, attempts=2, sleeptime=1) + + def testRetryOnlyRunOnce(self): + """Tests that retry() doesn't call the action again after success""" + self.s.retry(self._alwaysPass, attempts=3, sleeptime=0) + # self.ATTEMPT_N gets increased regardless of pass/fail + self.assertEqual(2, self.ATTEMPT_N) + + def testRetryReturns(self): + ret = self.s.retry(self._alwaysPass, sleeptime=0) + self.assertEqual(ret, True) + + def testRetryCleanupIsCalled(self): + cleanup = mock.Mock() + self.s.retry(self._succeedOnSecondAttempt, cleanup=cleanup, sleeptime=0) + self.assertEqual(cleanup.call_count, 1) + + def testRetryArgsPassed(self): + args = (1, "two", 3) + kwargs = dict(foo="a", bar=7) + ret = self.s.retry( + self._mirrorArgs, args=args, kwargs=kwargs.copy(), sleeptime=0 + ) + print(ret) + self.assertEqual(ret[0], args) + self.assertEqual(ret[1], kwargs) + + +class BaseScriptWithDecorators(script.BaseScript): + def __init__(self, *args, **kwargs): + super(BaseScriptWithDecorators, self).__init__(*args, **kwargs) + + self.pre_run_1_args = [] + self.raise_during_pre_run_1 = False + self.pre_action_1_args = [] + self.raise_during_pre_action_1 = False + self.pre_action_2_args = [] + self.pre_action_3_args = [] + self.post_action_1_args = [] + self.raise_during_post_action_1 = False + self.post_action_2_args = [] + self.post_action_3_args = [] + self.post_run_1_args = [] + self.raise_during_post_run_1 = False + self.post_run_2_args = [] + self.raise_during_build = False + + @script.PreScriptRun + def pre_run_1(self, *args, **kwargs): + self.pre_run_1_args.append((args, kwargs)) + + if self.raise_during_pre_run_1: + raise Exception(self.raise_during_pre_run_1) + + @script.PreScriptAction + def pre_action_1(self, *args, **kwargs): + self.pre_action_1_args.append((args, kwargs)) + + if self.raise_during_pre_action_1: + raise Exception(self.raise_during_pre_action_1) + + @script.PreScriptAction + def pre_action_2(self, *args, **kwargs): + self.pre_action_2_args.append((args, kwargs)) + + @script.PreScriptAction("clobber") + def pre_action_3(self, *args, **kwargs): + self.pre_action_3_args.append((args, kwargs)) + + @script.PostScriptAction + def post_action_1(self, *args, **kwargs): + self.post_action_1_args.append((args, kwargs)) + + if self.raise_during_post_action_1: + raise Exception(self.raise_during_post_action_1) + + @script.PostScriptAction + def post_action_2(self, *args, **kwargs): + self.post_action_2_args.append((args, kwargs)) + + @script.PostScriptAction("build") + def post_action_3(self, *args, **kwargs): + self.post_action_3_args.append((args, kwargs)) + + @script.PostScriptRun + def post_run_1(self, *args, **kwargs): + self.post_run_1_args.append((args, kwargs)) + + if self.raise_during_post_run_1: + raise Exception(self.raise_during_post_run_1) + + @script.PostScriptRun + def post_run_2(self, *args, **kwargs): + self.post_run_2_args.append((args, kwargs)) + + def build(self): + if self.raise_during_build: + raise Exception(self.raise_during_build) + + +class TestScriptDecorators(unittest.TestCase): + def setUp(self): + cleanup() + self.s = None + + def tearDown(self): + if hasattr(self, "s") and isinstance(self.s, object): + del self.s + + cleanup() + + def test_decorators_registered(self): + self.s = BaseScriptWithDecorators(initial_config_file="test/test.json") + + self.assertEqual(len(self.s._listeners["pre_run"]), 1) + self.assertEqual(len(self.s._listeners["pre_action"]), 3) + self.assertEqual(len(self.s._listeners["post_action"]), 3) + self.assertEqual(len(self.s._listeners["post_run"]), 2) + + def test_pre_post_fired(self): + self.s = BaseScriptWithDecorators(initial_config_file="test/test.json") + self.s.run() + + self.assertEqual(len(self.s.pre_run_1_args), 1) + self.assertEqual(len(self.s.pre_action_1_args), 2) + self.assertEqual(len(self.s.pre_action_2_args), 2) + self.assertEqual(len(self.s.pre_action_3_args), 1) + self.assertEqual(len(self.s.post_action_1_args), 2) + self.assertEqual(len(self.s.post_action_2_args), 2) + self.assertEqual(len(self.s.post_action_3_args), 1) + self.assertEqual(len(self.s.post_run_1_args), 1) + + self.assertEqual(self.s.pre_run_1_args[0], ((), {})) + + self.assertEqual(self.s.pre_action_1_args[0], (("clobber",), {})) + self.assertEqual(self.s.pre_action_1_args[1], (("build",), {})) + + # pre_action_3 should only get called for the action it is registered + # with. + self.assertEqual(self.s.pre_action_3_args[0], (("clobber",), {})) + + self.assertEqual(self.s.post_action_1_args[0][0], ("clobber",)) + self.assertEqual(self.s.post_action_1_args[0][1], dict(success=True)) + self.assertEqual(self.s.post_action_1_args[1][0], ("build",)) + self.assertEqual(self.s.post_action_1_args[1][1], dict(success=True)) + + # post_action_3 should only get called for the action it is registered + # with. + self.assertEqual(self.s.post_action_3_args[0], (("build",), dict(success=True))) + + self.assertEqual(self.s.post_run_1_args[0], ((), {})) + + def test_post_always_fired(self): + self.s = BaseScriptWithDecorators(initial_config_file="test/test.json") + self.s.raise_during_build = "Testing post always fired." + + with self.assertRaises(SystemExit): + self.s.run() + + self.assertEqual(len(self.s.pre_run_1_args), 1) + self.assertEqual(len(self.s.pre_action_1_args), 2) + self.assertEqual(len(self.s.post_action_1_args), 2) + self.assertEqual(len(self.s.post_action_2_args), 2) + self.assertEqual(len(self.s.post_run_1_args), 1) + self.assertEqual(len(self.s.post_run_2_args), 1) + + self.assertEqual(self.s.post_action_1_args[0][1], dict(success=True)) + self.assertEqual(self.s.post_action_1_args[1][1], dict(success=False)) + self.assertEqual(self.s.post_action_2_args[1][1], dict(success=False)) + + def test_pre_run_exception(self): + self.s = BaseScriptWithDecorators(initial_config_file="test/test.json") + self.s.raise_during_pre_run_1 = "Error during pre run 1" + + with self.assertRaises(SystemExit): + self.s.run() + + self.assertEqual(len(self.s.pre_run_1_args), 1) + self.assertEqual(len(self.s.pre_action_1_args), 0) + self.assertEqual(len(self.s.post_run_1_args), 1) + self.assertEqual(len(self.s.post_run_2_args), 1) + + def test_pre_action_exception(self): + self.s = BaseScriptWithDecorators(initial_config_file="test/test.json") + self.s.raise_during_pre_action_1 = "Error during pre 1" + + with self.assertRaises(SystemExit): + self.s.run() + + self.assertEqual(len(self.s.pre_run_1_args), 1) + self.assertEqual(len(self.s.pre_action_1_args), 1) + self.assertEqual(len(self.s.pre_action_2_args), 0) + self.assertEqual(len(self.s.post_action_1_args), 1) + self.assertEqual(len(self.s.post_action_2_args), 1) + self.assertEqual(len(self.s.post_run_1_args), 1) + self.assertEqual(len(self.s.post_run_2_args), 1) + + def test_post_action_exception(self): + self.s = BaseScriptWithDecorators(initial_config_file="test/test.json") + self.s.raise_during_post_action_1 = "Error during post 1" + + with self.assertRaises(SystemExit): + self.s.run() + + self.assertEqual(len(self.s.pre_run_1_args), 1) + self.assertEqual(len(self.s.post_action_1_args), 1) + self.assertEqual(len(self.s.post_action_2_args), 1) + self.assertEqual(len(self.s.post_run_1_args), 1) + self.assertEqual(len(self.s.post_run_2_args), 1) + + def test_post_run_exception(self): + self.s = BaseScriptWithDecorators(initial_config_file="test/test.json") + self.s.raise_during_post_run_1 = "Error during post run 1" + + with self.assertRaises(SystemExit): + self.s.run() + + self.assertEqual(len(self.s.post_run_1_args), 1) + self.assertEqual(len(self.s.post_run_2_args), 1) + + +# main {{{1 +if __name__ == "__main__": + unittest.main() |