summaryrefslogtreecommitdiffstats
path: root/cli_helpers/utils.py
blob: 3f09cb5b8e531ca67369c0958d05b01a3d5173f2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# -*- coding: utf-8 -*-
"""Various utility functions and helpers."""

import binascii
import re
from functools import lru_cache
from typing import Dict

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from pygments.style import StyleMeta

from cli_helpers.compat import binary_type, text_type, Terminal256Formatter, StringIO


def bytes_to_string(b):
    """Convert bytes *b* to a string.

    Hexlify bytes that can't be decoded.

    """
    if isinstance(b, binary_type):
        needs_hex = False
        try:
            result = b.decode("utf8")
            needs_hex = not result.isprintable()
        except UnicodeDecodeError:
            needs_hex = True
        if needs_hex:
            return "0x" + binascii.hexlify(b).decode("ascii")
        else:
            return result
    return b


def to_string(value):
    """Convert *value* to a string."""
    if isinstance(value, binary_type):
        return bytes_to_string(value)
    else:
        return text_type(value)


def truncate_string(value, max_width=None, skip_multiline_string=True):
    """Truncate string values."""
    if skip_multiline_string and isinstance(value, text_type) and "\n" in value:
        return value
    elif (
        isinstance(value, text_type)
        and max_width is not None
        and len(value) > max_width
    ):
        return value[: max_width - 3] + "..."
    return value


def intlen(n):
    """Find the length of the integer part of a number *n*."""
    pos = n.find(".")
    return len(n) if pos < 0 else pos


def filter_dict_by_key(d, keys):
    """Filter the dict *d* to remove keys not in *keys*."""
    return {k: v for k, v in d.items() if k in keys}


def unique_items(seq):
    """Return the unique items from iterable *seq* (in order)."""
    seen = set()
    return [x for x in seq if not (x in seen or seen.add(x))]


_ansi_re = re.compile("\033\\[((?:\\d|;)*)([a-zA-Z])")


def strip_ansi(value):
    """Strip the ANSI escape sequences from a string."""
    return _ansi_re.sub("", value)


def replace(s, replace):
    """Replace multiple values in a string"""
    for r in replace:
        s = s.replace(*r)
    return s


@lru_cache()
def _get_formatter(style) -> Terminal256Formatter:
    return Terminal256Formatter(style=style)


def style_field(token, field, style):
    """Get the styled text for a *field* using *token* type."""
    formatter = _get_formatter(style)
    s = StringIO()
    formatter.format(((token, field),), s)
    return s.getvalue()


def filter_style_table(style: "StyleMeta", *relevant_styles: str) -> Dict:
    """
    get a dictionary of styles for given tokens. Typical usage:

    filter_style_table(style, 'Token.Output.EvenRow', 'Token.Output.OddRow') == {
        'Token.Output.EvenRow': "",
        'Token.Output.OddRow': "",
    }
    """
    _styles_iter = (
        (str(key), val) for key, val in getattr(style, "styles", {}).items()
    )
    _relevant_styles_iter = filter(lambda tpl: tpl[0] in relevant_styles, _styles_iter)
    return {key: val for key, val in _relevant_styles_iter}