summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/__init__.py0
-rw-r--r--tests/compat.py86
-rw-r--r--tests/config_data/configrc18
-rw-r--r--tests/config_data/configspecrc20
-rw-r--r--tests/config_data/invalid_configrc18
-rw-r--r--tests/config_data/invalid_configspecrc20
-rw-r--r--tests/tabular_output/__init__.py0
-rw-r--r--tests/tabular_output/test_delimited_output_adapter.py55
-rw-r--r--tests/tabular_output/test_output_formatter.py271
-rw-r--r--tests/tabular_output/test_preprocessors.py350
-rw-r--r--tests/tabular_output/test_tabulate_adapter.py113
-rw-r--r--tests/tabular_output/test_tsv_output_adapter.py36
-rw-r--r--tests/tabular_output/test_vertical_table_adapter.py49
-rw-r--r--tests/test_cli_helpers.py5
-rw-r--r--tests/test_config.py293
-rw-r--r--tests/test_utils.py77
-rw-r--r--tests/utils.py18
17 files changed, 1429 insertions, 0 deletions
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/__init__.py
diff --git a/tests/compat.py b/tests/compat.py
new file mode 100644
index 0000000..dfc57f3
--- /dev/null
+++ b/tests/compat.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+"""Python compatibility support for CLI Helpers' tests."""
+
+from __future__ import unicode_literals
+import os as _os
+import shutil as _shutil
+import tempfile as _tempfile
+import warnings as _warnings
+
+from cli_helpers.compat import PY2
+
+
+class _TempDirectory(object):
+ """Create and return a temporary directory. This has the same
+ behavior as mkdtemp but can be used as a context manager. For
+ example:
+
+ with TemporaryDirectory() as tmpdir:
+ ...
+
+ Upon exiting the context, the directory and everything contained
+ in it are removed.
+
+ NOTE: Copied from the Python 3 standard library.
+ """
+
+ # Handle mkdtemp raising an exception
+ name = None
+ _closed = False
+
+ def __init__(self, suffix="", prefix="tmp", dir=None):
+ self.name = _tempfile.mkdtemp(suffix, prefix, dir)
+
+ def __repr__(self):
+ return "<{} {!r}>".format(self.__class__.__name__, self.name)
+
+ def __enter__(self):
+ return self.name
+
+ def cleanup(self, _warn=False, _warnings=_warnings):
+ if self.name and not self._closed:
+ try:
+ _shutil.rmtree(self.name)
+ except (TypeError, AttributeError) as ex:
+ if "None" not in "%s" % (ex,):
+ raise
+ self._rmtree(self.name)
+ self._closed = True
+ if _warn and _warnings.warn:
+ _warnings.warn(
+ "Implicitly cleaning up {!r}".format(self), ResourceWarning
+ )
+
+ def __exit__(self, exc, value, tb):
+ self.cleanup()
+
+ def __del__(self):
+ # Issue a ResourceWarning if implicit cleanup needed
+ self.cleanup(_warn=True)
+
+ def _rmtree(
+ self,
+ path,
+ _OSError=OSError,
+ _sep=_os.path.sep,
+ _listdir=_os.listdir,
+ _remove=_os.remove,
+ _rmdir=_os.rmdir,
+ ):
+ # Essentially a stripped down version of shutil.rmtree. We can't
+ # use globals because they may be None'ed out at shutdown.
+ if not isinstance(path, str):
+ _sep = _sep.encode()
+ try:
+ for name in _listdir(path):
+ fullname = path + _sep + name
+ try:
+ _remove(fullname)
+ except _OSError:
+ self._rmtree(fullname)
+ _rmdir(path)
+ except _OSError:
+ pass
+
+
+TemporaryDirectory = _TempDirectory if PY2 else _tempfile.TemporaryDirectory
diff --git a/tests/config_data/configrc b/tests/config_data/configrc
new file mode 100644
index 0000000..8050b58
--- /dev/null
+++ b/tests/config_data/configrc
@@ -0,0 +1,18 @@
+# vi: ft=dosini
+# Test file comment
+
+[section]
+# Test section comment
+
+# Test field comment
+test_boolean_default = True
+
+# Test field commented out
+# Uncomment to enable
+# test_boolean = True
+
+test_string_file = '~/myfile'
+
+test_option = 'foobar✔'
+
+[section2]
diff --git a/tests/config_data/configspecrc b/tests/config_data/configspecrc
new file mode 100644
index 0000000..afa1c6d
--- /dev/null
+++ b/tests/config_data/configspecrc
@@ -0,0 +1,20 @@
+# vi: ft=dosini
+# Test file comment
+
+[section]
+# Test section comment
+
+# Test field comment
+test_boolean_default = boolean(default=True)
+
+test_boolean = boolean()
+
+# Test field commented out
+# Uncomment to enable
+# test_boolean = True
+
+test_string_file = string(default='~/myfile')
+
+test_option = option('foo', 'bar', 'foobar', 'foobar✔', default='foobar✔')
+
+[section2]
diff --git a/tests/config_data/invalid_configrc b/tests/config_data/invalid_configrc
new file mode 100644
index 0000000..8e66190
--- /dev/null
+++ b/tests/config_data/invalid_configrc
@@ -0,0 +1,18 @@
+# vi: ft=dosini
+# Test file comment
+
+[section]
+# Test section comment
+
+# Test field comment
+test_boolean_default True
+
+# Test field commented out
+# Uncomment to enable
+# test_boolean = True
+
+test_string_file = '~/myfile'
+
+test_option = 'foobar✔'
+
+[section2]
diff --git a/tests/config_data/invalid_configspecrc b/tests/config_data/invalid_configspecrc
new file mode 100644
index 0000000..d405e52
--- /dev/null
+++ b/tests/config_data/invalid_configspecrc
@@ -0,0 +1,20 @@
+# vi: ft=dosini
+# Test file comment
+
+[section]
+# Test section comment
+
+# Test field comment
+test_boolean_default = boolean(default=True)
+
+test_boolean = bool(default=False)
+
+# Test field commented out
+# Uncomment to enable
+# test_boolean = True
+
+test_string_file = string(default='~/myfile')
+
+test_option = option('foo', 'bar', 'foobar', 'foobar✔', default='foobar✔')
+
+[section2]
diff --git a/tests/tabular_output/__init__.py b/tests/tabular_output/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/tabular_output/__init__.py
diff --git a/tests/tabular_output/test_delimited_output_adapter.py b/tests/tabular_output/test_delimited_output_adapter.py
new file mode 100644
index 0000000..86a622e
--- /dev/null
+++ b/tests/tabular_output/test_delimited_output_adapter.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+"""Test the delimited output adapter."""
+
+from __future__ import unicode_literals
+from textwrap import dedent
+
+import pytest
+
+from cli_helpers.tabular_output import delimited_output_adapter
+
+
+def test_csv_wrapper():
+ """Test the delimited output adapter."""
+ # Test comma-delimited output.
+ data = [["abc", "1"], ["d", "456"]]
+ headers = ["letters", "number"]
+ output = delimited_output_adapter.adapter(iter(data), headers, dialect="unix")
+ assert "\n".join(output) == dedent(
+ '''\
+ "letters","number"\n\
+ "abc","1"\n\
+ "d","456"'''
+ )
+
+ # Test tab-delimited output.
+ data = [["abc", "1"], ["d", "456"]]
+ headers = ["letters", "number"]
+ output = delimited_output_adapter.adapter(
+ iter(data), headers, table_format="csv-tab", dialect="unix"
+ )
+ assert "\n".join(output) == dedent(
+ '''\
+ "letters"\t"number"\n\
+ "abc"\t"1"\n\
+ "d"\t"456"'''
+ )
+
+ with pytest.raises(ValueError):
+ output = delimited_output_adapter.adapter(
+ iter(data), headers, table_format="foobar"
+ )
+ list(output)
+
+
+def test_unicode_with_csv():
+ """Test that the csv wrapper can handle non-ascii characters."""
+ data = [["观音", "1"], ["Ποσειδῶν", "456"]]
+ headers = ["letters", "number"]
+ output = delimited_output_adapter.adapter(data, headers)
+ assert "\n".join(output) == dedent(
+ """\
+ letters,number\n\
+ 观音,1\n\
+ Ποσειδῶν,456"""
+ )
diff --git a/tests/tabular_output/test_output_formatter.py b/tests/tabular_output/test_output_formatter.py
new file mode 100644
index 0000000..8e1fa92
--- /dev/null
+++ b/tests/tabular_output/test_output_formatter.py
@@ -0,0 +1,271 @@
+# -*- coding: utf-8 -*-
+"""Test the generic output formatter interface."""
+
+from __future__ import unicode_literals
+from decimal import Decimal
+from textwrap import dedent
+
+import pytest
+
+from cli_helpers.tabular_output import format_output, TabularOutputFormatter
+from cli_helpers.compat import binary_type, text_type
+from cli_helpers.utils import strip_ansi
+
+
+def test_tabular_output_formatter():
+ """Test the TabularOutputFormatter class."""
+ headers = ["text", "numeric"]
+ data = [
+ ["abc", Decimal(1)],
+ ["defg", Decimal("11.1")],
+ ["hi", Decimal("1.1")],
+ ["Pablo\rß\n", 0],
+ ]
+ expected = dedent(
+ """\
+ +-------+---------+
+ | text | numeric |
+ +-------+---------+
+ | abc | 1 |
+ | defg | 11.1 |
+ | hi | 1.1 |
+ | Pablo | 0 |
+ | ß | |
+ +-------+---------+"""
+ )
+
+ print(expected)
+ print(
+ "\n".join(
+ TabularOutputFormatter().format_output(
+ iter(data), headers, format_name="ascii"
+ )
+ )
+ )
+ assert expected == "\n".join(
+ TabularOutputFormatter().format_output(iter(data), headers, format_name="ascii")
+ )
+
+
+def test_tabular_output_escaped():
+ """Test the ascii_escaped output format."""
+ headers = ["text", "numeric"]
+ data = [
+ ["abc", Decimal(1)],
+ ["defg", Decimal("11.1")],
+ ["hi", Decimal("1.1")],
+ ["Pablo\rß\n", 0],
+ ]
+ expected = dedent(
+ """\
+ +------------+---------+
+ | text | numeric |
+ +------------+---------+
+ | abc | 1 |
+ | defg | 11.1 |
+ | hi | 1.1 |
+ | Pablo\\rß\\n | 0 |
+ +------------+---------+"""
+ )
+
+ print(expected)
+ print(
+ "\n".join(
+ TabularOutputFormatter().format_output(
+ iter(data), headers, format_name="ascii_escaped"
+ )
+ )
+ )
+ assert expected == "\n".join(
+ TabularOutputFormatter().format_output(
+ iter(data), headers, format_name="ascii_escaped"
+ )
+ )
+
+
+def test_tabular_format_output_wrapper():
+ """Test the format_output wrapper."""
+ data = [["1", None], ["2", "Sam"], ["3", "Joe"]]
+ headers = ["id", "name"]
+ expected = dedent(
+ """\
+ +----+------+
+ | id | name |
+ +----+------+
+ | 1 | N/A |
+ | 2 | Sam |
+ | 3 | Joe |
+ +----+------+"""
+ )
+
+ assert expected == "\n".join(
+ format_output(iter(data), headers, format_name="ascii", missing_value="N/A")
+ )
+
+
+def test_additional_preprocessors():
+ """Test that additional preprocessors are run."""
+
+ def hello_world(data, headers, **_):
+ def hello_world_data(data):
+ for row in data:
+ for i, value in enumerate(row):
+ if value == "hello":
+ row[i] = "{}, world".format(value)
+ yield row
+
+ return hello_world_data(data), headers
+
+ data = [["foo", None], ["hello!", "hello"]]
+ headers = "ab"
+
+ expected = dedent(
+ """\
+ +--------+--------------+
+ | a | b |
+ +--------+--------------+
+ | foo | hello |
+ | hello! | hello, world |
+ +--------+--------------+"""
+ )
+
+ assert expected == "\n".join(
+ TabularOutputFormatter().format_output(
+ iter(data),
+ headers,
+ format_name="ascii",
+ preprocessors=(hello_world,),
+ missing_value="hello",
+ )
+ )
+
+
+def test_format_name_attribute():
+ """Test the the format_name attribute be set and retrieved."""
+ formatter = TabularOutputFormatter(format_name="plain")
+ assert formatter.format_name == "plain"
+ formatter.format_name = "simple"
+ assert formatter.format_name == "simple"
+
+ with pytest.raises(ValueError):
+ formatter.format_name = "foobar"
+
+
+def test_headless_tabulate_format():
+ """Test that a headless formatter doesn't display headers"""
+ formatter = TabularOutputFormatter(format_name="minimal")
+ headers = ["text", "numeric"]
+ data = [["a"], ["b"], ["c"]]
+ expected = "a\nb\nc"
+ assert expected == "\n".join(
+ TabularOutputFormatter().format_output(
+ iter(data),
+ headers,
+ format_name="minimal",
+ )
+ )
+
+
+def test_unsupported_format():
+ """Test that TabularOutputFormatter rejects unknown formats."""
+ formatter = TabularOutputFormatter()
+
+ with pytest.raises(ValueError):
+ formatter.format_name = "foobar"
+
+ with pytest.raises(ValueError):
+ formatter.format_output((), (), format_name="foobar")
+
+
+def test_tabulate_ansi_escape_in_default_value():
+ """Test that ANSI escape codes work with tabulate."""
+
+ data = [["1", None], ["2", "Sam"], ["3", "Joe"]]
+ headers = ["id", "name"]
+
+ styled = format_output(
+ iter(data),
+ headers,
+ format_name="psql",
+ missing_value="\x1b[38;5;10mNULL\x1b[39m",
+ )
+ unstyled = format_output(
+ iter(data), headers, format_name="psql", missing_value="NULL"
+ )
+
+ stripped_styled = [strip_ansi(s) for s in styled]
+
+ assert list(unstyled) == stripped_styled
+
+
+def test_get_type():
+ """Test that _get_type returns the expected type."""
+ formatter = TabularOutputFormatter()
+
+ tests = (
+ (1, int),
+ (2.0, float),
+ (b"binary", binary_type),
+ ("text", text_type),
+ (None, type(None)),
+ ((), text_type),
+ )
+
+ for value, data_type in tests:
+ assert data_type is formatter._get_type(value)
+
+
+def test_provide_column_types():
+ """Test that provided column types are passed to preprocessors."""
+ expected_column_types = (bool, float)
+ data = ((1, 1.0), (0, 2))
+ headers = ("a", "b")
+
+ def preprocessor(data, headers, column_types=(), **_):
+ assert expected_column_types == column_types
+ return data, headers
+
+ format_output(
+ data,
+ headers,
+ "csv",
+ column_types=expected_column_types,
+ preprocessors=(preprocessor,),
+ )
+
+
+def test_enforce_iterable():
+ """Test that all output formatters accept iterable"""
+ formatter = TabularOutputFormatter()
+ loremipsum = (
+ "lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod".split(
+ " "
+ )
+ )
+
+ for format_name in formatter.supported_formats:
+ formatter.format_name = format_name
+ try:
+ formatted = next(formatter.format_output(zip(loremipsum), ["lorem"]))
+ except TypeError:
+ assert False, "{0} doesn't return iterable".format(format_name)
+
+
+@pytest.mark.parametrize(
+ "extra_kwargs",
+ [
+ {},
+ {"style": "default"},
+ {"style": "colorful"},
+ ],
+)
+def test_all_text_type(extra_kwargs):
+ """Test the TabularOutputFormatter class."""
+ data = [[1, "", None, Decimal(2)]]
+ headers = ["col1", "col2", "col3", "col4"]
+ output_formatter = TabularOutputFormatter()
+ for format_name in output_formatter.supported_formats:
+ for row in output_formatter.format_output(
+ iter(data), headers, format_name=format_name, **extra_kwargs
+ ):
+ assert isinstance(row, text_type), "not unicode for {}".format(format_name)
diff --git a/tests/tabular_output/test_preprocessors.py b/tests/tabular_output/test_preprocessors.py
new file mode 100644
index 0000000..e428bfa
--- /dev/null
+++ b/tests/tabular_output/test_preprocessors.py
@@ -0,0 +1,350 @@
+# -*- coding: utf-8 -*-
+"""Test CLI Helpers' tabular output preprocessors."""
+
+from __future__ import unicode_literals
+from decimal import Decimal
+
+import pytest
+
+from cli_helpers.compat import HAS_PYGMENTS
+from cli_helpers.tabular_output.preprocessors import (
+ align_decimals,
+ bytes_to_string,
+ convert_to_string,
+ quote_whitespaces,
+ override_missing_value,
+ override_tab_value,
+ style_output,
+ format_numbers,
+)
+
+if HAS_PYGMENTS:
+ from pygments.style import Style
+ from pygments.token import Token
+
+import inspect
+import cli_helpers.tabular_output.preprocessors
+import types
+
+
+def test_convert_to_string():
+ """Test the convert_to_string() function."""
+ data = [[1, "John"], [2, "Jill"]]
+ headers = [0, "name"]
+ expected = ([["1", "John"], ["2", "Jill"]], ["0", "name"])
+ results = convert_to_string(data, headers)
+
+ assert expected == (list(results[0]), results[1])
+
+
+def test_override_missing_values():
+ """Test the override_missing_values() function."""
+ data = [[1, None], [2, "Jill"]]
+ headers = [0, "name"]
+ expected = ([[1, "<EMPTY>"], [2, "Jill"]], [0, "name"])
+ results = override_missing_value(data, headers, missing_value="<EMPTY>")
+
+ assert expected == (list(results[0]), results[1])
+
+
+@pytest.mark.skipif(not HAS_PYGMENTS, reason="requires the Pygments library")
+def test_override_missing_value_with_style():
+ """Test that *override_missing_value()* styles output."""
+
+ class NullStyle(Style):
+ styles = {Token.Output.Null: "#0f0"}
+
+ headers = ["h1", "h2"]
+ data = [[None, "2"], ["abc", None]]
+
+ expected_headers = ["h1", "h2"]
+ expected_data = [
+ ["\x1b[38;5;10m<null>\x1b[39m", "2"],
+ ["abc", "\x1b[38;5;10m<null>\x1b[39m"],
+ ]
+ results = override_missing_value(
+ data, headers, style=NullStyle, missing_value="<null>"
+ )
+
+ assert (expected_data, expected_headers) == (list(results[0]), results[1])
+
+
+def test_override_tab_value():
+ """Test the override_tab_value() function."""
+ data = [[1, "\tJohn"], [2, "Jill"]]
+ headers = ["id", "name"]
+ expected = ([[1, " John"], [2, "Jill"]], ["id", "name"])
+ results = override_tab_value(data, headers)
+
+ assert expected == (list(results[0]), results[1])
+
+
+def test_bytes_to_string():
+ """Test the bytes_to_string() function."""
+ data = [[1, "John"], [2, b"Jill"]]
+ headers = [0, "name"]
+ expected = ([[1, "John"], [2, "Jill"]], [0, "name"])
+ results = bytes_to_string(data, headers)
+
+ assert expected == (list(results[0]), results[1])
+
+
+def test_align_decimals():
+ """Test the align_decimals() function."""
+ data = [[Decimal("200"), Decimal("1")], [Decimal("1.00002"), Decimal("1.0")]]
+ headers = ["num1", "num2"]
+ column_types = (float, float)
+ expected = ([["200", "1"], [" 1.00002", "1.0"]], ["num1", "num2"])
+ results = align_decimals(data, headers, column_types=column_types)
+
+ assert expected == (list(results[0]), results[1])
+
+
+def test_align_decimals_empty_result():
+ """Test align_decimals() with no results."""
+ data = []
+ headers = ["num1", "num2"]
+ column_types = ()
+ expected = ([], ["num1", "num2"])
+ results = align_decimals(data, headers, column_types=column_types)
+
+ assert expected == (list(results[0]), results[1])
+
+
+def test_align_decimals_non_decimals():
+ """Test align_decimals() with non-decimals."""
+ data = [[Decimal("200.000"), Decimal("1.000")], [None, None]]
+ headers = ["num1", "num2"]
+ column_types = (float, float)
+ expected = ([["200.000", "1.000"], [None, None]], ["num1", "num2"])
+ results = align_decimals(data, headers, column_types=column_types)
+
+ assert expected == (list(results[0]), results[1])
+
+
+def test_quote_whitespaces():
+ """Test the quote_whitespaces() function."""
+ data = [[" before", "after "], [" both ", "none"]]
+ headers = ["h1", "h2"]
+ expected = ([["' before'", "'after '"], ["' both '", "'none'"]], ["h1", "h2"])
+ results = quote_whitespaces(data, headers)
+
+ assert expected == (list(results[0]), results[1])
+
+
+def test_quote_whitespaces_empty_result():
+ """Test the quote_whitespaces() function with no results."""
+ data = []
+ headers = ["h1", "h2"]
+ expected = ([], ["h1", "h2"])
+ results = quote_whitespaces(data, headers)
+
+ assert expected == (list(results[0]), results[1])
+
+
+def test_quote_whitespaces_non_spaces():
+ """Test the quote_whitespaces() function with non-spaces."""
+ data = [["\tbefore", "after \r"], ["\n both ", "none"]]
+ headers = ["h1", "h2"]
+ expected = ([["'\tbefore'", "'after \r'"], ["'\n both '", "'none'"]], ["h1", "h2"])
+ results = quote_whitespaces(data, headers)
+
+ assert expected == (list(results[0]), results[1])
+
+
+@pytest.mark.skipif(not HAS_PYGMENTS, reason="requires the Pygments library")
+def test_style_output_no_styles():
+ """Test that *style_output()* does not style without styles."""
+ headers = ["h1", "h2"]
+ data = [["1", "2"], ["a", "b"]]
+ results = style_output(data, headers)
+
+ assert (data, headers) == (list(results[0]), results[1])
+
+
+@pytest.mark.skipif(HAS_PYGMENTS, reason="requires the Pygments library be missing")
+def test_style_output_no_pygments():
+ """Test that *style_output()* does not try to style without Pygments."""
+ headers = ["h1", "h2"]
+ data = [["1", "2"], ["a", "b"]]
+ results = style_output(data, headers)
+
+ assert (data, headers) == (list(results[0]), results[1])
+
+
+@pytest.mark.skipif(not HAS_PYGMENTS, reason="requires the Pygments library")
+def test_style_output():
+ """Test that *style_output()* styles output."""
+
+ class CliStyle(Style):
+ default_style = ""
+ styles = {
+ Token.Output.Header: "bold ansibrightred",
+ Token.Output.OddRow: "bg:#eee #111",
+ Token.Output.EvenRow: "#0f0",
+ }
+
+ headers = ["h1", "h2"]
+ data = [["观音", "2"], ["Ποσειδῶν", "b"]]
+
+ expected_headers = ["\x1b[91;01mh1\x1b[39;00m", "\x1b[91;01mh2\x1b[39;00m"]
+ expected_data = [
+ ["\x1b[38;5;233;48;5;7m观音\x1b[39;49m", "\x1b[38;5;233;48;5;7m2\x1b[39;49m"],
+ ["\x1b[38;5;10mΠοσειδῶν\x1b[39m", "\x1b[38;5;10mb\x1b[39m"],
+ ]
+ results = style_output(data, headers, style=CliStyle)
+
+ assert (expected_data, expected_headers) == (list(results[0]), results[1])
+
+
+@pytest.mark.skipif(not HAS_PYGMENTS, reason="requires the Pygments library")
+def test_style_output_with_newlines():
+ """Test that *style_output()* styles output with newlines in it."""
+
+ class CliStyle(Style):
+ default_style = ""
+ styles = {
+ Token.Output.Header: "bold ansibrightred",
+ Token.Output.OddRow: "bg:#eee #111",
+ Token.Output.EvenRow: "#0f0",
+ }
+
+ headers = ["h1", "h2"]
+ data = [["观音\nLine2", "Ποσειδῶν"]]
+
+ expected_headers = ["\x1b[91;01mh1\x1b[39;00m", "\x1b[91;01mh2\x1b[39;00m"]
+ expected_data = [
+ [
+ "\x1b[38;5;233;48;5;7m观音\x1b[39;49m\n\x1b[38;5;233;48;5;7m"
+ "Line2\x1b[39;49m",
+ "\x1b[38;5;233;48;5;7mΠοσειδῶν\x1b[39;49m",
+ ]
+ ]
+ results = style_output(data, headers, style=CliStyle)
+
+ assert (expected_data, expected_headers) == (list(results[0]), results[1])
+
+
+@pytest.mark.skipif(not HAS_PYGMENTS, reason="requires the Pygments library")
+def test_style_output_custom_tokens():
+ """Test that *style_output()* styles output with custom token names."""
+
+ class CliStyle(Style):
+ default_style = ""
+ styles = {
+ Token.Results.Headers: "bold ansibrightred",
+ Token.Results.OddRows: "bg:#eee #111",
+ Token.Results.EvenRows: "#0f0",
+ }
+
+ headers = ["h1", "h2"]
+ data = [["1", "2"], ["a", "b"]]
+
+ expected_headers = ["\x1b[91;01mh1\x1b[39;00m", "\x1b[91;01mh2\x1b[39;00m"]
+ expected_data = [
+ ["\x1b[38;5;233;48;5;7m1\x1b[39;49m", "\x1b[38;5;233;48;5;7m2\x1b[39;49m"],
+ ["\x1b[38;5;10ma\x1b[39m", "\x1b[38;5;10mb\x1b[39m"],
+ ]
+
+ output = style_output(
+ data,
+ headers,
+ style=CliStyle,
+ header_token=Token.Results.Headers,
+ odd_row_token=Token.Results.OddRows,
+ even_row_token=Token.Results.EvenRows,
+ )
+
+ assert (expected_data, expected_headers) == (list(output[0]), output[1])
+
+
+def test_format_integer():
+ """Test formatting for an INTEGER datatype."""
+ data = [[1], [1000], [1000000]]
+ headers = ["h1"]
+ result_data, result_headers = format_numbers(
+ data, headers, column_types=(int,), integer_format=",", float_format=","
+ )
+
+ expected = [["1"], ["1,000"], ["1,000,000"]]
+ assert expected == list(result_data)
+ assert headers == result_headers
+
+
+def test_format_decimal():
+ """Test formatting for a DECIMAL(12, 4) datatype."""
+ data = [[Decimal("1.0000")], [Decimal("1000.0000")], [Decimal("1000000.0000")]]
+ headers = ["h1"]
+ result_data, result_headers = format_numbers(
+ data, headers, column_types=(float,), integer_format=",", float_format=","
+ )
+
+ expected = [["1.0000"], ["1,000.0000"], ["1,000,000.0000"]]
+ assert expected == list(result_data)
+ assert headers == result_headers
+
+
+def test_format_float():
+ """Test formatting for a REAL datatype."""
+ data = [[1.0], [1000.0], [1000000.0]]
+ headers = ["h1"]
+ result_data, result_headers = format_numbers(
+ data, headers, column_types=(float,), integer_format=",", float_format=","
+ )
+ expected = [["1.0"], ["1,000.0"], ["1,000,000.0"]]
+ assert expected == list(result_data)
+ assert headers == result_headers
+
+
+def test_format_integer_only():
+ """Test that providing one format string works."""
+ data = [[1, 1.0], [1000, 1000.0], [1000000, 1000000.0]]
+ headers = ["h1", "h2"]
+ result_data, result_headers = format_numbers(
+ data, headers, column_types=(int, float), integer_format=","
+ )
+
+ expected = [["1", 1.0], ["1,000", 1000.0], ["1,000,000", 1000000.0]]
+ assert expected == list(result_data)
+ assert headers == result_headers
+
+
+def test_format_numbers_no_format_strings():
+ """Test that numbers aren't formatted without format strings."""
+ data = ((1), (1000), (1000000))
+ headers = ("h1",)
+ result_data, result_headers = format_numbers(data, headers, column_types=(int,))
+ assert list(data) == list(result_data)
+ assert headers == result_headers
+
+
+def test_format_numbers_no_column_types():
+ """Test that numbers aren't formatted without column types."""
+ data = ((1), (1000), (1000000))
+ headers = ("h1",)
+ result_data, result_headers = format_numbers(
+ data, headers, integer_format=",", float_format=","
+ )
+ assert list(data) == list(result_data)
+ assert headers == result_headers
+
+
+def test_enforce_iterable():
+ preprocessors = inspect.getmembers(
+ cli_helpers.tabular_output.preprocessors, inspect.isfunction
+ )
+ loremipsum = (
+ "lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod".split(
+ " "
+ )
+ )
+ for name, preprocessor in preprocessors:
+ preprocessed = preprocessor(zip(loremipsum), ["lorem"], column_types=(str,))
+ try:
+ first = next(preprocessed[0])
+ except StopIteration:
+ assert False, "{} gives no output with iterator data".format(name)
+ except TypeError:
+ assert False, "{} doesn't return iterable".format(name)
+ if isinstance(preprocessed[1], types.GeneratorType):
+ assert False, "{} returns headers as iterator".format(name)
diff --git a/tests/tabular_output/test_tabulate_adapter.py b/tests/tabular_output/test_tabulate_adapter.py
new file mode 100644
index 0000000..6e7c7db
--- /dev/null
+++ b/tests/tabular_output/test_tabulate_adapter.py
@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+"""Test the tabulate output adapter."""
+
+from __future__ import unicode_literals
+from textwrap import dedent
+
+import pytest
+
+from cli_helpers.compat import HAS_PYGMENTS
+from cli_helpers.tabular_output import tabulate_adapter
+
+if HAS_PYGMENTS:
+ from pygments.style import Style
+ from pygments.token import Token
+
+
+def test_tabulate_wrapper():
+ """Test the *output_formatter.tabulate_wrapper()* function."""
+ data = [["abc", 1], ["d", 456]]
+ headers = ["letters", "number"]
+ output = tabulate_adapter.adapter(iter(data), headers, table_format="psql")
+ assert "\n".join(output) == dedent(
+ """\
+ +---------+--------+
+ | letters | number |
+ |---------+--------|
+ | abc | 1 |
+ | d | 456 |
+ +---------+--------+"""
+ )
+
+ data = [["abc", 1], ["d", 456]]
+ headers = ["letters", "number"]
+ output = tabulate_adapter.adapter(iter(data), headers, table_format="psql_unicode")
+ assert "\n".join(output) == dedent(
+ """\
+ ┌─────────┬────────┐
+ │ letters │ number │
+ ├─────────┼────────┤
+ │ abc │ 1 │
+ │ d │ 456 │
+ └─────────┴────────┘"""
+ )
+
+ data = [["{1,2,3}", "{{1,2},{3,4}}", "{å,魚,текст}"], ["{}", "<null>", "{<null>}"]]
+ headers = ["bigint_array", "nested_numeric_array", "配列"]
+ output = tabulate_adapter.adapter(iter(data), headers, table_format="psql")
+ assert "\n".join(output) == dedent(
+ """\
+ +--------------+----------------------+--------------+
+ | bigint_array | nested_numeric_array | 配列 |
+ |--------------+----------------------+--------------|
+ | {1,2,3} | {{1,2},{3,4}} | {å,魚,текст} |
+ | {} | <null> | {<null>} |
+ +--------------+----------------------+--------------+"""
+ )
+
+
+def test_markup_format():
+ """Test that markup formats do not have number align or string align."""
+ data = [["abc", 1], ["d", 456]]
+ headers = ["letters", "number"]
+ output = tabulate_adapter.adapter(iter(data), headers, table_format="mediawiki")
+ assert "\n".join(output) == dedent(
+ """\
+ {| class="wikitable" style="text-align: left;"
+ |+ <!-- caption -->
+ |-
+ ! letters !! number
+ |-
+ | abc || 1
+ |-
+ | d || 456
+ |}"""
+ )
+
+
+@pytest.mark.skipif(not HAS_PYGMENTS, reason="requires the Pygments library")
+def test_style_output_table():
+ """Test that *style_output_table()* styles the output table."""
+
+ class CliStyle(Style):
+ default_style = ""
+ styles = {
+ Token.Output.TableSeparator: "ansibrightred",
+ }
+
+ headers = ["h1", "h2"]
+ data = [["观音", "2"], ["Ποσειδῶν", "b"]]
+ style_output_table = tabulate_adapter.style_output_table("psql")
+
+ style_output_table(data, headers, style=CliStyle)
+ output = tabulate_adapter.adapter(iter(data), headers, table_format="psql")
+ PLUS = "\x1b[91m+\x1b[39m"
+ MINUS = "\x1b[91m-\x1b[39m"
+ PIPE = "\x1b[91m|\x1b[39m"
+
+ expected = (
+ dedent(
+ """\
+ +----------+----+
+ | h1 | h2 |
+ |----------+----|
+ | 观音 | 2 |
+ | Ποσειδῶν | b |
+ +----------+----+"""
+ )
+ .replace("+", PLUS)
+ .replace("-", MINUS)
+ .replace("|", PIPE)
+ )
+
+ assert "\n".join(output) == expected
diff --git a/tests/tabular_output/test_tsv_output_adapter.py b/tests/tabular_output/test_tsv_output_adapter.py
new file mode 100644
index 0000000..9249d87
--- /dev/null
+++ b/tests/tabular_output/test_tsv_output_adapter.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+"""Test the tsv delimited output adapter."""
+
+from __future__ import unicode_literals
+from textwrap import dedent
+
+import pytest
+
+from cli_helpers.tabular_output import tsv_output_adapter
+
+
+def test_tsv_wrapper():
+ """Test the tsv output adapter."""
+ # Test tab-delimited output.
+ data = [["ab\r\nc", "1"], ["d", "456"]]
+ headers = ["letters", "number"]
+ output = tsv_output_adapter.adapter(iter(data), headers, table_format="tsv")
+ assert "\n".join(output) == dedent(
+ """\
+ letters\tnumber\n\
+ ab\r\\nc\t1\n\
+ d\t456"""
+ )
+
+
+def test_unicode_with_tsv():
+ """Test that the tsv wrapper can handle non-ascii characters."""
+ data = [["观音", "1"], ["Ποσειδῶν", "456"]]
+ headers = ["letters", "number"]
+ output = tsv_output_adapter.adapter(data, headers)
+ assert "\n".join(output) == dedent(
+ """\
+ letters\tnumber\n\
+ 观音\t1\n\
+ Ποσειδῶν\t456"""
+ )
diff --git a/tests/tabular_output/test_vertical_table_adapter.py b/tests/tabular_output/test_vertical_table_adapter.py
new file mode 100644
index 0000000..359d9d9
--- /dev/null
+++ b/tests/tabular_output/test_vertical_table_adapter.py
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+"""Test the vertical table formatter."""
+
+from textwrap import dedent
+
+from cli_helpers.compat import text_type
+from cli_helpers.tabular_output import vertical_table_adapter
+
+
+def test_vertical_table():
+ """Test the default settings for vertical_table()."""
+ results = [("hello", text_type(123)), ("world", text_type(456))]
+
+ expected = dedent(
+ """\
+ ***************************[ 1. row ]***************************
+ name | hello
+ age | 123
+ ***************************[ 2. row ]***************************
+ name | world
+ age | 456"""
+ )
+ assert expected == "\n".join(
+ vertical_table_adapter.adapter(results, ("name", "age"))
+ )
+
+
+def test_vertical_table_customized():
+ """Test customized settings for vertical_table()."""
+ results = [("john", text_type(47)), ("jill", text_type(50))]
+
+ expected = dedent(
+ """\
+ -[ PERSON 1 ]-----
+ name | john
+ age | 47
+ -[ PERSON 2 ]-----
+ name | jill
+ age | 50"""
+ )
+ assert expected == "\n".join(
+ vertical_table_adapter.adapter(
+ results,
+ ("name", "age"),
+ sep_title="PERSON {n}",
+ sep_character="-",
+ sep_length=(1, 5),
+ )
+ )
diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py
new file mode 100644
index 0000000..50b8c81
--- /dev/null
+++ b/tests/test_cli_helpers.py
@@ -0,0 +1,5 @@
+import cli_helpers
+
+
+def test_cli_helpers():
+ assert True
diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644
index 0000000..131bc8c
--- /dev/null
+++ b/tests/test_config.py
@@ -0,0 +1,293 @@
+# -*- coding: utf-8 -*-
+"""Test the cli_helpers.config module."""
+
+from __future__ import unicode_literals
+import os
+
+from unittest.mock import MagicMock
+import pytest
+
+from cli_helpers.compat import MAC, text_type, WIN
+from cli_helpers.config import (
+ Config,
+ DefaultConfigValidationError,
+ get_system_config_dirs,
+ get_user_config_dir,
+ _pathify,
+)
+from .utils import with_temp_dir
+
+APP_NAME, APP_AUTHOR = "Test", "Acme"
+TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "config_data")
+DEFAULT_CONFIG = {
+ "section": {
+ "test_boolean_default": "True",
+ "test_string_file": "~/myfile",
+ "test_option": "foobar✔",
+ },
+ "section2": {},
+}
+DEFAULT_VALID_CONFIG = {
+ "section": {
+ "test_boolean_default": True,
+ "test_string_file": "~/myfile",
+ "test_option": "foobar✔",
+ },
+ "section2": {},
+}
+
+
+def _mocked_user_config(temp_dir, *args, **kwargs):
+ config = Config(*args, **kwargs)
+ config.user_config_file = MagicMock(
+ return_value=os.path.join(temp_dir, config.filename)
+ )
+ return config
+
+
+def test_user_config_dir():
+ """Test that the config directory is a string with the app name in it."""
+ if "XDG_CONFIG_HOME" in os.environ:
+ del os.environ["XDG_CONFIG_HOME"]
+ config_dir = get_user_config_dir(APP_NAME, APP_AUTHOR)
+ assert isinstance(config_dir, text_type)
+ assert config_dir.endswith(APP_NAME) or config_dir.endswith(_pathify(APP_NAME))
+
+
+def test_sys_config_dirs():
+ """Test that the sys config directories are returned correctly."""
+ if "XDG_CONFIG_DIRS" in os.environ:
+ del os.environ["XDG_CONFIG_DIRS"]
+ config_dirs = get_system_config_dirs(APP_NAME, APP_AUTHOR)
+ assert isinstance(config_dirs, list)
+ assert config_dirs[0].endswith(APP_NAME) or config_dirs[0].endswith(
+ _pathify(APP_NAME)
+ )
+
+
+@pytest.mark.skipif(not WIN, reason="requires Windows")
+def test_windows_user_config_dir_no_roaming():
+ """Test that Windows returns the user config directory without roaming."""
+ config_dir = get_user_config_dir(APP_NAME, APP_AUTHOR, roaming=False)
+ assert isinstance(config_dir, text_type)
+ assert config_dir.endswith(APP_NAME)
+ assert "Local" in config_dir
+
+
+@pytest.mark.skipif(not MAC, reason="requires macOS")
+def test_mac_user_config_dir_no_xdg():
+ """Test that macOS returns the user config directory without XDG."""
+ config_dir = get_user_config_dir(APP_NAME, APP_AUTHOR, force_xdg=False)
+ assert isinstance(config_dir, text_type)
+ assert config_dir.endswith(APP_NAME)
+ assert "Library" in config_dir
+
+
+@pytest.mark.skipif(not MAC, reason="requires macOS")
+def test_mac_system_config_dirs_no_xdg():
+ """Test that macOS returns the system config directories without XDG."""
+ config_dirs = get_system_config_dirs(APP_NAME, APP_AUTHOR, force_xdg=False)
+ assert isinstance(config_dirs, list)
+ assert config_dirs[0].endswith(APP_NAME)
+ assert "Library" in config_dirs[0]
+
+
+def test_config_reading_raise_errors():
+ """Test that instantiating Config will raise errors when appropriate."""
+ with pytest.raises(ValueError):
+ Config(APP_NAME, APP_AUTHOR, "test_config", write_default=True)
+
+ with pytest.raises(ValueError):
+ Config(APP_NAME, APP_AUTHOR, "test_config", validate=True)
+
+ with pytest.raises(TypeError):
+ Config(APP_NAME, APP_AUTHOR, "test_config", default=b"test")
+
+
+def test_config_user_file():
+ """Test that the Config user_config_file is appropriate."""
+ config = Config(APP_NAME, APP_AUTHOR, "test_config")
+ assert get_user_config_dir(APP_NAME, APP_AUTHOR) in config.user_config_file()
+
+
+def test_config_reading_default_dict():
+ """Test that the Config constructor will read in defaults from a dict."""
+ default = {"main": {"foo": "bar"}}
+ config = Config(APP_NAME, APP_AUTHOR, "test_config", default=default)
+ assert config.data == default
+
+
+def test_config_reading_no_default():
+ """Test that the Config constructor will work without any defaults."""
+ config = Config(APP_NAME, APP_AUTHOR, "test_config")
+ assert config.data == {}
+
+
+def test_config_reading_default_file():
+ """Test that the Config will work with a default file."""
+ config = Config(
+ APP_NAME,
+ APP_AUTHOR,
+ "test_config",
+ default=os.path.join(TEST_DATA_DIR, "configrc"),
+ )
+ config.read_default_config()
+ assert config.data == DEFAULT_CONFIG
+
+
+def test_config_reading_configspec():
+ """Test that the Config default file will work with a configspec."""
+ config = Config(
+ APP_NAME,
+ APP_AUTHOR,
+ "test_config",
+ validate=True,
+ default=os.path.join(TEST_DATA_DIR, "configspecrc"),
+ )
+ config.read_default_config()
+ assert config.data == DEFAULT_VALID_CONFIG
+
+
+def test_config_reading_configspec_with_error():
+ """Test that reading an invalid configspec raises and exception."""
+ with pytest.raises(DefaultConfigValidationError):
+ config = Config(
+ APP_NAME,
+ APP_AUTHOR,
+ "test_config",
+ validate=True,
+ default=os.path.join(TEST_DATA_DIR, "invalid_configspecrc"),
+ )
+ config.read_default_config()
+
+
+@with_temp_dir
+def test_write_and_read_default_config(temp_dir=None):
+ config_file = "test_config"
+ default_file = os.path.join(TEST_DATA_DIR, "configrc")
+ temp_config_file = os.path.join(temp_dir, config_file)
+
+ config = _mocked_user_config(
+ temp_dir, APP_NAME, APP_AUTHOR, config_file, default=default_file
+ )
+ config.read_default_config()
+ config.write_default_config()
+
+ user_config = _mocked_user_config(
+ temp_dir, APP_NAME, APP_AUTHOR, config_file, default=default_file
+ )
+ user_config.read()
+ assert temp_config_file in user_config.config_filenames
+ assert user_config == config
+
+ with open(temp_config_file) as f:
+ contents = f.read()
+ assert "# Test file comment" in contents
+ assert "# Test section comment" in contents
+ assert "# Test field comment" in contents
+ assert "# Test field commented out" in contents
+
+
+@with_temp_dir
+def test_write_and_read_default_config_from_configspec(temp_dir=None):
+ config_file = "test_config"
+ default_file = os.path.join(TEST_DATA_DIR, "configspecrc")
+ temp_config_file = os.path.join(temp_dir, config_file)
+
+ config = _mocked_user_config(
+ temp_dir, APP_NAME, APP_AUTHOR, config_file, default=default_file, validate=True
+ )
+ config.read_default_config()
+ config.write_default_config()
+
+ user_config = _mocked_user_config(
+ temp_dir, APP_NAME, APP_AUTHOR, config_file, default=default_file, validate=True
+ )
+ user_config.read()
+ assert temp_config_file in user_config.config_filenames
+ assert user_config == config
+
+ with open(temp_config_file) as f:
+ contents = f.read()
+ assert "# Test file comment" in contents
+ assert "# Test section comment" in contents
+ assert "# Test field comment" in contents
+ assert "# Test field commented out" in contents
+
+
+@with_temp_dir
+def test_overwrite_default_config_from_configspec(temp_dir=None):
+ config_file = "test_config"
+ default_file = os.path.join(TEST_DATA_DIR, "configspecrc")
+ temp_config_file = os.path.join(temp_dir, config_file)
+
+ config = _mocked_user_config(
+ temp_dir, APP_NAME, APP_AUTHOR, config_file, default=default_file, validate=True
+ )
+ config.read_default_config()
+ config.write_default_config()
+
+ with open(temp_config_file, "a") as f:
+ f.write("--APPEND--")
+
+ config.write_default_config()
+
+ with open(temp_config_file) as f:
+ assert "--APPEND--" in f.read()
+
+ config.write_default_config(overwrite=True)
+
+ with open(temp_config_file) as f:
+ assert "--APPEND--" not in f.read()
+
+
+def test_read_invalid_config_file():
+ config_file = "invalid_configrc"
+
+ config = _mocked_user_config(TEST_DATA_DIR, APP_NAME, APP_AUTHOR, config_file)
+ config.read()
+ assert "section" in config
+ assert "test_string_file" in config["section"]
+ assert "test_boolean_default" not in config["section"]
+ assert "section2" in config
+
+
+@with_temp_dir
+def test_write_to_user_config(temp_dir=None):
+ config_file = "test_config"
+ default_file = os.path.join(TEST_DATA_DIR, "configrc")
+ temp_config_file = os.path.join(temp_dir, config_file)
+
+ config = _mocked_user_config(
+ temp_dir, APP_NAME, APP_AUTHOR, config_file, default=default_file
+ )
+ config.read_default_config()
+ config.write_default_config()
+
+ with open(temp_config_file) as f:
+ assert "test_boolean_default = True" in f.read()
+
+ config["section"]["test_boolean_default"] = False
+ config.write()
+
+ with open(temp_config_file) as f:
+ assert "test_boolean_default = False" in f.read()
+
+
+@with_temp_dir
+def test_write_to_outfile(temp_dir=None):
+ config_file = "test_config"
+ outfile = os.path.join(temp_dir, "foo")
+ default_file = os.path.join(TEST_DATA_DIR, "configrc")
+
+ config = _mocked_user_config(
+ temp_dir, APP_NAME, APP_AUTHOR, config_file, default=default_file
+ )
+ config.read_default_config()
+ config.write_default_config()
+
+ config["section"]["test_boolean_default"] = False
+ config.write(outfile=outfile)
+
+ with open(outfile) as f:
+ assert "test_boolean_default = False" in f.read()
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..ba43937
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,77 @@
+# -*- coding: utf-8 -*-
+"""Test CLI Helpers' utility functions and helpers."""
+
+from __future__ import unicode_literals
+
+from cli_helpers import utils
+
+
+def test_bytes_to_string_hexlify():
+ """Test that bytes_to_string() hexlifies binary data."""
+ assert utils.bytes_to_string(b"\xff") == "0xff"
+
+
+def test_bytes_to_string_decode_bytes():
+ """Test that bytes_to_string() decodes bytes."""
+ assert utils.bytes_to_string(b"foobar") == "foobar"
+
+
+def test_bytes_to_string_unprintable():
+ """Test that bytes_to_string() hexlifies data that is valid unicode, but unprintable."""
+ assert utils.bytes_to_string(b"\0") == "0x00"
+ assert utils.bytes_to_string(b"\1") == "0x01"
+ assert utils.bytes_to_string(b"a\0") == "0x6100"
+
+
+def test_bytes_to_string_non_bytes():
+ """Test that bytes_to_string() returns non-bytes untouched."""
+ assert utils.bytes_to_string("abc") == "abc"
+ assert utils.bytes_to_string(1) == 1
+
+
+def test_to_string_bytes():
+ """Test that to_string() converts bytes to a string."""
+ assert utils.to_string(b"foo") == "foo"
+
+
+def test_to_string_non_bytes():
+ """Test that to_string() converts non-bytes to a string."""
+ assert utils.to_string(1) == "1"
+ assert utils.to_string(2.29) == "2.29"
+
+
+def test_truncate_string():
+ """Test string truncate preprocessor."""
+ val = "x" * 100
+ assert utils.truncate_string(val, 10) == "xxxxxxx..."
+
+ val = "x " * 100
+ assert utils.truncate_string(val, 10) == "x x x x..."
+
+ val = "x" * 100
+ assert utils.truncate_string(val) == "x" * 100
+
+ val = ["x"] * 100
+ val[20] = "\n"
+ str_val = "".join(val)
+ assert utils.truncate_string(str_val, 10, skip_multiline_string=True) == str_val
+
+
+def test_intlen_with_decimal():
+ """Test that intlen() counts correctly with a decimal place."""
+ assert utils.intlen("11.1") == 2
+ assert utils.intlen("1.1") == 1
+
+
+def test_intlen_without_decimal():
+ """Test that intlen() counts correctly without a decimal place."""
+ assert utils.intlen("11") == 2
+
+
+def test_filter_dict_by_key():
+ """Test that filter_dict_by_key() filter unwanted items."""
+ keys = ("foo", "bar")
+ d = {"foo": 1, "foobar": 2}
+ fd = utils.filter_dict_by_key(d, keys)
+ assert len(fd) == 1
+ assert all([k in keys for k in fd])
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000..0088eec
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,18 @@
+# -*- coding: utf-8 -*-
+"""Utility functions for CLI Helpers' tests."""
+
+from __future__ import unicode_literals
+from functools import wraps
+
+from .compat import TemporaryDirectory
+
+
+def with_temp_dir(f):
+ """A wrapper that creates and deletes a temporary directory."""
+
+ @wraps(f)
+ def wrapped(*args, **kwargs):
+ with TemporaryDirectory() as temp_dir:
+ return f(*args, temp_dir=temp_dir, **kwargs)
+
+ return wrapped