summaryrefslogtreecommitdiffstats
path: root/tests/tests_contrib_logging.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/tests_contrib_logging.py')
-rw-r--r--tests/tests_contrib_logging.py173
1 files changed, 173 insertions, 0 deletions
diff --git a/tests/tests_contrib_logging.py b/tests/tests_contrib_logging.py
new file mode 100644
index 0000000..6f675dd
--- /dev/null
+++ b/tests/tests_contrib_logging.py
@@ -0,0 +1,173 @@
+# pylint: disable=missing-module-docstring, missing-class-docstring
+# pylint: disable=missing-function-docstring, no-self-use
+from __future__ import absolute_import
+
+import logging
+import logging.handlers
+import sys
+from io import StringIO
+
+import pytest
+
+from tqdm import tqdm
+from tqdm.contrib.logging import _get_first_found_console_logging_handler
+from tqdm.contrib.logging import _TqdmLoggingHandler as TqdmLoggingHandler
+from tqdm.contrib.logging import logging_redirect_tqdm, tqdm_logging_redirect
+
+from .tests_tqdm import importorskip
+
+LOGGER = logging.getLogger(__name__)
+
+TEST_LOGGING_FORMATTER = logging.Formatter()
+
+
+class CustomTqdm(tqdm):
+ messages = []
+
+ @classmethod
+ def write(cls, s, **__): # pylint: disable=arguments-differ
+ CustomTqdm.messages.append(s)
+
+
+class ErrorRaisingTqdm(tqdm):
+ exception_class = RuntimeError
+
+ @classmethod
+ def write(cls, s, **__): # pylint: disable=arguments-differ
+ raise ErrorRaisingTqdm.exception_class('fail fast')
+
+
+class TestTqdmLoggingHandler:
+ def test_should_call_tqdm_write(self):
+ CustomTqdm.messages = []
+ logger = logging.Logger('test')
+ logger.handlers = [TqdmLoggingHandler(CustomTqdm)]
+ logger.info('test')
+ assert CustomTqdm.messages == ['test']
+
+ def test_should_call_handle_error_if_exception_was_thrown(self):
+ patch = importorskip('unittest.mock').patch
+ logger = logging.Logger('test')
+ ErrorRaisingTqdm.exception_class = RuntimeError
+ handler = TqdmLoggingHandler(ErrorRaisingTqdm)
+ logger.handlers = [handler]
+ with patch.object(handler, 'handleError') as mock:
+ logger.info('test')
+ assert mock.called
+
+ @pytest.mark.parametrize('exception_class', [
+ KeyboardInterrupt,
+ SystemExit
+ ])
+ def test_should_not_swallow_certain_exceptions(self, exception_class):
+ logger = logging.Logger('test')
+ ErrorRaisingTqdm.exception_class = exception_class
+ handler = TqdmLoggingHandler(ErrorRaisingTqdm)
+ logger.handlers = [handler]
+ with pytest.raises(exception_class):
+ logger.info('test')
+
+
+class TestGetFirstFoundConsoleLoggingHandler:
+ def test_should_return_none_for_no_handlers(self):
+ assert _get_first_found_console_logging_handler([]) is None
+
+ def test_should_return_none_without_stream_handler(self):
+ handler = logging.handlers.MemoryHandler(capacity=1)
+ assert _get_first_found_console_logging_handler([handler]) is None
+
+ def test_should_return_none_for_stream_handler_not_stdout_or_stderr(self):
+ handler = logging.StreamHandler(StringIO())
+ assert _get_first_found_console_logging_handler([handler]) is None
+
+ def test_should_return_stream_handler_if_stream_is_stdout(self):
+ handler = logging.StreamHandler(sys.stdout)
+ assert _get_first_found_console_logging_handler([handler]) == handler
+
+ def test_should_return_stream_handler_if_stream_is_stderr(self):
+ handler = logging.StreamHandler(sys.stderr)
+ assert _get_first_found_console_logging_handler([handler]) == handler
+
+
+class TestRedirectLoggingToTqdm:
+ def test_should_add_and_remove_tqdm_handler(self):
+ logger = logging.Logger('test')
+ with logging_redirect_tqdm(loggers=[logger]):
+ assert len(logger.handlers) == 1
+ assert isinstance(logger.handlers[0], TqdmLoggingHandler)
+ assert not logger.handlers
+
+ def test_should_remove_and_restore_console_handlers(self):
+ logger = logging.Logger('test')
+ stderr_console_handler = logging.StreamHandler(sys.stderr)
+ stdout_console_handler = logging.StreamHandler(sys.stderr)
+ logger.handlers = [stderr_console_handler, stdout_console_handler]
+ with logging_redirect_tqdm(loggers=[logger]):
+ assert len(logger.handlers) == 1
+ assert isinstance(logger.handlers[0], TqdmLoggingHandler)
+ assert logger.handlers == [stderr_console_handler, stdout_console_handler]
+
+ def test_should_inherit_console_logger_formatter(self):
+ logger = logging.Logger('test')
+ formatter = logging.Formatter('custom: %(message)s')
+ console_handler = logging.StreamHandler(sys.stderr)
+ console_handler.setFormatter(formatter)
+ logger.handlers = [console_handler]
+ with logging_redirect_tqdm(loggers=[logger]):
+ assert logger.handlers[0].formatter == formatter
+
+ def test_should_not_remove_stream_handlers_not_for_stdout_or_stderr(self):
+ logger = logging.Logger('test')
+ stream_handler = logging.StreamHandler(StringIO())
+ logger.addHandler(stream_handler)
+ with logging_redirect_tqdm(loggers=[logger]):
+ assert len(logger.handlers) == 2
+ assert logger.handlers[0] == stream_handler
+ assert isinstance(logger.handlers[1], TqdmLoggingHandler)
+ assert logger.handlers == [stream_handler]
+
+
+class TestTqdmWithLoggingRedirect:
+ def test_should_add_and_remove_handler_from_root_logger_by_default(self):
+ original_handlers = list(logging.root.handlers)
+ with tqdm_logging_redirect(total=1) as pbar:
+ assert isinstance(logging.root.handlers[-1], TqdmLoggingHandler)
+ LOGGER.info('test')
+ pbar.update(1)
+ assert logging.root.handlers == original_handlers
+
+ def test_should_add_and_remove_handler_from_custom_logger(self):
+ logger = logging.Logger('test')
+ with tqdm_logging_redirect(total=1, loggers=[logger]) as pbar:
+ assert len(logger.handlers) == 1
+ assert isinstance(logger.handlers[0], TqdmLoggingHandler)
+ logger.info('test')
+ pbar.update(1)
+ assert not logger.handlers
+
+ def test_should_not_fail_with_logger_without_console_handler(self):
+ logger = logging.Logger('test')
+ logger.handlers = []
+ with tqdm_logging_redirect(total=1, loggers=[logger]):
+ logger.info('test')
+ assert not logger.handlers
+
+ def test_should_format_message(self):
+ logger = logging.Logger('test')
+ console_handler = logging.StreamHandler(sys.stdout)
+ console_handler.setFormatter(logging.Formatter(
+ r'prefix:%(message)s'
+ ))
+ logger.handlers = [console_handler]
+ CustomTqdm.messages = []
+ with tqdm_logging_redirect(loggers=[logger], tqdm_class=CustomTqdm):
+ logger.info('test')
+ assert CustomTqdm.messages == ['prefix:test']
+
+ def test_use_root_logger_by_default_and_write_to_custom_tqdm(self):
+ logger = logging.root
+ CustomTqdm.messages = []
+ with tqdm_logging_redirect(total=1, tqdm_class=CustomTqdm) as pbar:
+ assert isinstance(pbar, CustomTqdm)
+ logger.info('test')
+ assert CustomTqdm.messages == ['test']