summaryrefslogtreecommitdiffstats
path: root/gitlint/tests/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'gitlint/tests/base.py')
-rw-r--r--gitlint/tests/base.py41
1 files changed, 40 insertions, 1 deletions
diff --git a/gitlint/tests/base.py b/gitlint/tests/base.py
index add4d71..c8f68c4 100644
--- a/gitlint/tests/base.py
+++ b/gitlint/tests/base.py
@@ -1,10 +1,13 @@
# -*- coding: utf-8 -*-
+import contextlib
import copy
import io
import logging
import os
import re
+import shutil
+import tempfile
try:
# python 2.x
@@ -21,7 +24,7 @@ except ImportError:
from unittest.mock import patch # pylint: disable=no-name-in-module, import-error
from gitlint.git import GitContext
-from gitlint.utils import ustr, LOG_FORMAT, DEFAULT_ENCODING
+from gitlint.utils import ustr, IS_PY2, LOG_FORMAT, DEFAULT_ENCODING
# unittest2's assertRaisesRegex doesn't do unicode comparison.
@@ -57,6 +60,15 @@ class BaseTestCase(unittest.TestCase):
logging.getLogger('gitlint').propagate = False
@staticmethod
+ @contextlib.contextmanager
+ def tempdir():
+ tmpdir = tempfile.mkdtemp()
+ try:
+ yield tmpdir
+ finally:
+ shutil.rmtree(tmpdir)
+
+ @staticmethod
def get_sample_path(filename=""):
# Don't join up empty files names because this will add a trailing slash
if filename == "":
@@ -73,6 +85,15 @@ class BaseTestCase(unittest.TestCase):
return sample
@staticmethod
+ def patch_input(side_effect):
+ """ Patches the built-in input() with a provided side-effect """
+ module_path = "builtins.input"
+ if IS_PY2:
+ module_path = "__builtin__.raw_input"
+ patched_module = patch(module_path, side_effect=side_effect)
+ return patched_module
+
+ @staticmethod
def get_expected(filename="", variable_dict=None):
""" Utility method to read an expected file from gitlint/tests/expected and return it as a string.
Optionally replace template variables specified by variable_dict. """
@@ -129,6 +150,24 @@ class BaseTestCase(unittest.TestCase):
return super(BaseTestCase, self).assertRaisesRegex(expected_exception, re.escape(expected_regex),
*args, **kwargs)
+ @contextlib.contextmanager
+ def assertRaisesMessage(self, expected_exception, expected_msg): # pylint: disable=invalid-name
+ """ Asserts an exception has occurred with a given error message """
+ try:
+ yield
+ except expected_exception as exc:
+ exception_msg = ustr(exc)
+ if exception_msg != expected_msg:
+ error = u"Right exception, wrong message:\n got: {0}\n expected: {1}"
+ raise self.fail(error.format(exception_msg, expected_msg))
+ # else: everything is fine, just return
+ return
+ except Exception as exc:
+ raise self.fail(u"Expected '{0}' got '{1}'".format(expected_exception.__name__, exc.__class__.__name__))
+
+ # No exception raised while we expected one
+ raise self.fail("Expected to raise {0}, didn't get an exception at all".format(expected_exception.__name__))
+
def object_equality_test(self, obj, attr_list, ctor_kwargs=None):
""" Helper function to easily implement object equality tests.
Creates an object clone for every passed attribute and checks for (in)equality