diff options
Diffstat (limited to '')
-rw-r--r-- | gitlint/tests/base.py | 41 |
1 files changed, 40 insertions, 1 deletions
diff --git a/gitlint/tests/base.py b/gitlint/tests/base.py index add4d71..c8f68c4 100644 --- a/gitlint/tests/base.py +++ b/gitlint/tests/base.py @@ -1,10 +1,13 @@ # -*- coding: utf-8 -*- +import contextlib import copy import io import logging import os import re +import shutil +import tempfile try: # python 2.x @@ -21,7 +24,7 @@ except ImportError: from unittest.mock import patch # pylint: disable=no-name-in-module, import-error from gitlint.git import GitContext -from gitlint.utils import ustr, LOG_FORMAT, DEFAULT_ENCODING +from gitlint.utils import ustr, IS_PY2, LOG_FORMAT, DEFAULT_ENCODING # unittest2's assertRaisesRegex doesn't do unicode comparison. @@ -57,6 +60,15 @@ class BaseTestCase(unittest.TestCase): logging.getLogger('gitlint').propagate = False @staticmethod + @contextlib.contextmanager + def tempdir(): + tmpdir = tempfile.mkdtemp() + try: + yield tmpdir + finally: + shutil.rmtree(tmpdir) + + @staticmethod def get_sample_path(filename=""): # Don't join up empty files names because this will add a trailing slash if filename == "": @@ -73,6 +85,15 @@ class BaseTestCase(unittest.TestCase): return sample @staticmethod + def patch_input(side_effect): + """ Patches the built-in input() with a provided side-effect """ + module_path = "builtins.input" + if IS_PY2: + module_path = "__builtin__.raw_input" + patched_module = patch(module_path, side_effect=side_effect) + return patched_module + + @staticmethod def get_expected(filename="", variable_dict=None): """ Utility method to read an expected file from gitlint/tests/expected and return it as a string. Optionally replace template variables specified by variable_dict. """ @@ -129,6 +150,24 @@ class BaseTestCase(unittest.TestCase): return super(BaseTestCase, self).assertRaisesRegex(expected_exception, re.escape(expected_regex), *args, **kwargs) + @contextlib.contextmanager + def assertRaisesMessage(self, expected_exception, expected_msg): # pylint: disable=invalid-name + """ Asserts an exception has occurred with a given error message """ + try: + yield + except expected_exception as exc: + exception_msg = ustr(exc) + if exception_msg != expected_msg: + error = u"Right exception, wrong message:\n got: {0}\n expected: {1}" + raise self.fail(error.format(exception_msg, expected_msg)) + # else: everything is fine, just return + return + except Exception as exc: + raise self.fail(u"Expected '{0}' got '{1}'".format(expected_exception.__name__, exc.__class__.__name__)) + + # No exception raised while we expected one + raise self.fail("Expected to raise {0}, didn't get an exception at all".format(expected_exception.__name__)) + def object_equality_test(self, obj, attr_list, ctor_kwargs=None): """ Helper function to easily implement object equality tests. Creates an object clone for every passed attribute and checks for (in)equality |