summaryrefslogtreecommitdiffstats
path: root/gitlint-core/gitlint/tests/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'gitlint-core/gitlint/tests/base.py')
-rw-r--r--gitlint-core/gitlint/tests/base.py227
1 files changed, 227 insertions, 0 deletions
diff --git a/gitlint-core/gitlint/tests/base.py b/gitlint-core/gitlint/tests/base.py
new file mode 100644
index 0000000..3899a5f
--- /dev/null
+++ b/gitlint-core/gitlint/tests/base.py
@@ -0,0 +1,227 @@
+import contextlib
+import copy
+import logging
+import os
+import re
+import shutil
+import tempfile
+import unittest
+from pathlib import Path
+from unittest.mock import patch
+
+from gitlint.config import LintConfig
+from gitlint.deprecation import LOG as DEPRECATION_LOG
+from gitlint.deprecation import Deprecation
+from gitlint.git import GitChangedFileStats, GitContext
+from gitlint.utils import FILE_ENCODING, LOG_FORMAT
+
+EXPECTED_REGEX_STYLE_SEARCH_DEPRECATION_WARNING = (
+ "WARNING: gitlint.deprecated.regex_style_search {0} - {1}: gitlint will be switching from using "
+ "Python regex 'match' (match beginning) to 'search' (match anywhere) semantics. "
+ "Please review your {1}.regex option accordingly. "
+ "To remove this warning, set general.regex-style-search=True. More details: "
+ "https://jorisroovers.github.io/gitlint/configuration/#regex-style-search"
+)
+
+
+class BaseTestCase(unittest.TestCase):
+ """Base class of which all gitlint unit test classes are derived. Provides a number of convenience methods."""
+
+ # In case of assert failures, print the full error message
+ maxDiff = None
+
+ # Working directory in which tests in this class are executed
+ working_dir = None
+ # Originally working dir when the test was started
+ original_working_dir = None
+
+ SAMPLES_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "samples")
+ EXPECTED_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "expected")
+ GITLINT_USE_SH_LIB = os.environ.get("GITLINT_USE_SH_LIB", "[NOT SET]")
+
+ @classmethod
+ def setUpClass(cls):
+ # Run tests a temporary directory to shield them from any local git config
+ cls.original_working_dir = os.getcwd()
+ cls.working_dir = tempfile.mkdtemp()
+ os.chdir(cls.working_dir)
+
+ @classmethod
+ def tearDownClass(cls):
+ # Go back to original working dir and remove our temp working dir
+ os.chdir(cls.original_working_dir)
+ shutil.rmtree(cls.working_dir)
+
+ def setUp(self):
+ self.logcapture = LogCapture()
+ self.logcapture.setFormatter(logging.Formatter(LOG_FORMAT))
+ logging.getLogger("gitlint").setLevel(logging.DEBUG)
+ logging.getLogger("gitlint").handlers = [self.logcapture]
+ DEPRECATION_LOG.handlers = [self.logcapture]
+
+ # Make sure we don't propagate anything to child loggers, we need to do this explicitly here
+ # because if you run a specific test file like test_lint.py, we won't be calling the setupLogging() method
+ # in gitlint.cli that normally takes care of this
+ # Example test where this matters (for DEPRECATION_LOG):
+ # gitlint-core/gitlint/tests/rules/test_configuration_rules.py::ConfigurationRuleTests::test_ignore_by_title
+ logging.getLogger("gitlint").propagate = False
+ DEPRECATION_LOG.propagate = False
+
+ # Make sure Deprecation has a clean config set at the start of each test.
+ # Tests that want to specifically test deprecation should override this.
+ Deprecation.config = LintConfig()
+ # Normally Deprecation only logs messages once per process.
+ # For tests we want to log every time, so we reset the warning_msgs set per test.
+ Deprecation.warning_msgs = set()
+
+ @staticmethod
+ @contextlib.contextmanager
+ def tempdir():
+ tmpdir = tempfile.mkdtemp()
+ try:
+ yield tmpdir
+ finally:
+ shutil.rmtree(tmpdir)
+
+ @staticmethod
+ def get_sample_path(filename=""):
+ # Don't join up empty files names because this will add a trailing slash
+ if filename == "":
+ return BaseTestCase.SAMPLES_DIR
+
+ return os.path.join(BaseTestCase.SAMPLES_DIR, filename)
+
+ @staticmethod
+ def get_sample(filename=""):
+ """Read and return the contents of a file in gitlint/tests/samples"""
+ sample_path = BaseTestCase.get_sample_path(filename)
+ return Path(sample_path).read_text(encoding=FILE_ENCODING)
+
+ @staticmethod
+ def patch_input(side_effect):
+ """Patches the built-in input() with a provided side-effect"""
+ module_path = "builtins.input"
+ patched_module = patch(module_path, side_effect=side_effect)
+ return patched_module
+
+ @staticmethod
+ def get_expected(filename="", variable_dict=None):
+ """Utility method to read an expected file from gitlint/tests/expected and return it as a string.
+ Optionally replace template variables specified by variable_dict."""
+ expected_path = os.path.join(BaseTestCase.EXPECTED_DIR, filename)
+ expected = Path(expected_path).read_text(encoding=FILE_ENCODING)
+
+ if variable_dict:
+ expected = expected.format(**variable_dict)
+ return expected
+
+ @staticmethod
+ def get_user_rules_path():
+ return os.path.join(BaseTestCase.SAMPLES_DIR, "user_rules")
+
+ @staticmethod
+ def gitcontext(commit_msg_str, changed_files=None):
+ """Utility method to easily create gitcontext objects based on a given commit msg string and an optional set of
+ changed files"""
+ with patch("gitlint.git.git_commentchar") as comment_char:
+ comment_char.return_value = "#"
+ gitcontext = GitContext.from_commit_msg(commit_msg_str)
+ commit = gitcontext.commits[-1]
+ if changed_files:
+ changed_file_stats = {filename: GitChangedFileStats(filename, 8, 3) for filename in changed_files}
+ commit.changed_files_stats = changed_file_stats
+ return gitcontext
+
+ @staticmethod
+ def gitcommit(commit_msg_str, changed_files=None, **kwargs):
+ """Utility method to easily create git commit given a commit msg string and an optional set of changed files"""
+ gitcontext = BaseTestCase.gitcontext(commit_msg_str, changed_files)
+ commit = gitcontext.commits[-1]
+ for attr, value in kwargs.items():
+ setattr(commit, attr, value)
+ return commit
+
+ def assert_logged(self, expected):
+ """Asserts that the logs match an expected string or list.
+ This method knows how to compare a passed list of log lines as well as a newline concatenated string
+ of all loglines."""
+ if isinstance(expected, list):
+ self.assertListEqual(self.logcapture.messages, expected)
+ else:
+ self.assertEqual("\n".join(self.logcapture.messages), expected)
+
+ def assert_log_contains(self, line):
+ """Asserts that a certain line is in the logs"""
+ self.assertIn(line, self.logcapture.messages)
+
+ def assertRaisesRegex(self, expected_exception, expected_regex, *args, **kwargs):
+ """Pass-through method to unittest.TestCase.assertRaisesRegex that applies re.escape() to the passed
+ `expected_regex`. This is useful to automatically escape all file paths that might be present in the regex.
+ """
+ return super().assertRaisesRegex(expected_exception, re.escape(expected_regex), *args, **kwargs)
+
+ def clearlog(self):
+ """Clears the log capture"""
+ self.logcapture.clear()
+
+ @contextlib.contextmanager
+ def assertRaisesMessage(self, expected_exception, expected_msg):
+ """Asserts an exception has occurred with a given error message"""
+ try:
+ yield
+ except expected_exception as exc:
+ exception_msg = str(exc)
+ if exception_msg != expected_msg: # pragma: nocover
+ error = f"Right exception, wrong message:\n got: {exception_msg}\n expected: {expected_msg}"
+ raise self.fail(error) from exc
+ # else: everything is fine, just return
+ return
+ except Exception as exc: # pragma: nocover
+ raise self.fail(f"Expected '{expected_exception.__name__}' got '{exc.__class__.__name__}'") from exc
+
+ # No exception raised while we expected one
+ raise self.fail(
+ f"Expected to raise {expected_exception.__name__}, didn't get an exception at all"
+ ) # pragma: nocover
+
+ def object_equality_test(self, obj, attr_list, ctor_kwargs=None):
+ """Helper function to easily implement object equality tests.
+ Creates an object clone for every passed attribute and checks for (in)equality
+ of the original object with the clone based on those attributes' values.
+ This function assumes all attributes in `attr_list` can be passed to the ctor of `obj.__class__`.
+ """
+ if not ctor_kwargs:
+ ctor_kwargs = {}
+
+ attr_kwargs = {}
+ for attr in attr_list:
+ attr_kwargs[attr] = getattr(obj, attr)
+
+ # For every attr, clone the object and assert the clone and the original object are equal
+ # Then, change the current attr and assert objects are unequal
+ for attr in attr_list:
+ attr_kwargs_copy = copy.deepcopy(attr_kwargs)
+ attr_kwargs_copy.update(ctor_kwargs)
+ clone = obj.__class__(**attr_kwargs_copy)
+ self.assertEqual(obj, clone)
+
+ # Change attribute and assert objects are different (via both attribute set and ctor)
+ setattr(clone, attr, "föo")
+ self.assertNotEqual(obj, clone)
+ attr_kwargs_copy[attr] = "föo"
+
+ self.assertNotEqual(obj, obj.__class__(**attr_kwargs_copy))
+
+
+class LogCapture(logging.Handler):
+ """Mock logging handler used to capture any log messages during tests."""
+
+ def __init__(self, *args, **kwargs):
+ logging.Handler.__init__(self, *args, **kwargs)
+ self.messages = []
+
+ def emit(self, record):
+ self.messages.append(self.format(record))
+
+ def clear(self):
+ self.messages = []