summaryrefslogtreecommitdiffstats
path: root/iredis/utils.py
blob: b11097dc2f95db63fe676b18ebee4caf57005112 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
import re
import sys
import time
import logging
from collections import namedtuple
from urllib.parse import parse_qs, unquote, urlparse

from prompt_toolkit.formatted_text import FormattedText

from iredis.exceptions import InvalidArguments


logger = logging.getLogger(__name__)

# State for ``timer()``: wall-clock time of the previous checkpoint and the
# number of checkpoints logged so far.
_last_timer = time.time()
_timer_counter = 0
# Matches one whitespace character; ``strip_quote_args`` uses it to split
# arguments outside quotes. (Misspelling of "separator"; NOTE(review): kept
# as-is since other modules may import this name.)
sperator = re.compile(r"\s")
logger.debug(f"[timer] start on {_last_timer}")


def timer(title):
    """
    Log, at debug level, the time elapsed since the previous checkpoint.

    The first checkpoint is the module import time. Each call logs a line
    tagged with a running counter and then becomes the new checkpoint.

    :param title: label describing this checkpoint.
    """
    global _last_timer, _timer_counter

    current = time.time()
    elapsed = current - _last_timer
    logger.debug(f"[timer{_timer_counter:2}] {elapsed:.8f} -> {title}")

    _timer_counter += 1
    _last_timer = current


def nativestr(x):
    """
    Return ``x`` as a ``str``.

    Strings are returned untouched. Anything else is assumed to be bytes-like
    and is decoded as UTF-8, replacing undecodable bytes with U+FFFD.
    """
    if isinstance(x, str):
        return x
    return x.decode("utf-8", "replace")


def literal_bytes(b):
    """
    Render bytes as the text between the quotes of their Python literal.

    ``b'abc'`` becomes ``abc``; non-printable bytes keep their escaped form
    (e.g. ``\\x00``), and an escaped single quote stays escaped. Values that
    are not ``bytes`` are returned unchanged.
    """
    if not isinstance(b, bytes):
        return b
    return str(b)[2:-1]


def _valide_token(words):
    token = "".join(words).strip()
    if token:
        yield token


def strip_quote_args(s):
    """
    Split string ``s`` into arguments, roughly like a shell would.

    Outside quotes, whitespace (anything matched by ``sperator``) separates
    arguments, and runs of whitespace never produce empty arguments. A single
    or double quote opens a quoted section that lasts until the same quote
    character appears again. Inside it, a backslash directly before that
    quote character produces a literal quote; every other backslash sequence
    is kept verbatim. Outside quotes a backslash has no special meaning.

    Details of the splitting:

    - Characters before an opening quote are kept as a prefix:
      ``ab"cd"`` yields ``abcd``.
    - A closing quote always ends the argument: ``"ab"cd`` yields ``ab``
      and ``cd``.
    - An empty quoted section yields an empty string.

    This is a generator; arguments are yielded as they are completed.

    :raises InvalidArguments: if a quote is left unclosed. Any non-empty
        partial argument is yielded before the exception is raised.
    :return: generator of argument strings.
    """
    word = []  # characters of the argument currently being built
    in_quote = None  # quote char that opened the current section, or None
    pre_back_slash = False  # previous char was a backslash that can escape
    for char in s:
        if in_quote:
            # matching quote: ends the section unless it is escaped
            if char == in_quote:
                if not pre_back_slash:
                    yield "".join(word)
                    word = []
                    in_quote = None
                else:
                    # escaped quote: overwrite the backslash already stored
                    # in ``word`` with the quote character itself
                    word[-1] = char
            else:
                word.append(char)
        # not in quote
        else:
            # whitespace ends the current (non-empty) argument
            if sperator.match(char):
                if word:
                    yield "".join(word)
                    word = []
            # open quotes
            elif char in ["'", '"']:
                in_quote = char
            else:
                word.append(char)
        # A backslash can escape the next char only if it was not itself
        # escaped, so ``\\`` followed by a quote still closes the quote.
        if char == "\\" and not pre_back_slash:
            pre_back_slash = True
        else:
            pre_back_slash = False

    if word:
        yield "".join(word)
    # unterminated quote (any partial argument was yielded just above)
    if in_quote:
        raise InvalidArguments("Invalid argument(s)")


# Maps an argument type from commands.json to the style-class suffix used
# when rendering it in ``parse_argument_to_formatted_text``; types not listed
# here are used unchanged.
type_convert = {"posix time": "time"}


def parse_argument_to_formatted_text(
    name, _type, is_option, style_class="bottom-toolbar"
):
    """
    Render one argument spec from commands.json as FormattedText fragments.

    :param name: argument name (``str``), or a list of names for a
        multi-part argument.
    :param _type: type name (``str``), or a list of type names parallel to
        ``name``. It is mapped through ``type_convert`` to pick the style
        class. With list input, extra entries in the longer of ``name`` and
        ``_type`` are ignored (``zip`` semantics).
    :param is_option: when truthy and ``name`` is a list, each name is
        wrapped in brackets. A single ``str`` name is never bracketed.
    :param style_class: prefix used to build the ``class:`` style names.
    :return: list of ``(style, text)`` tuples, each text prefixed by a space.
    :raises TypeError: if ``name`` is neither ``str`` nor ``list``.
    """
    if isinstance(name, str):
        _type = type_convert.get(_type, _type)
        return [(f"class:{style_class}.{_type}", " " + name)]
    if not isinstance(name, list):
        # TypeError subclasses Exception, so existing handlers still catch it.
        raise TypeError(
            "argument name must be str or list, "
            f"got {type(name).__name__}: {name!r}"
        )

    template = " [{}]" if is_option else " {}"
    result = []
    for inner_name, inner_type in zip(name, _type):
        inner_type = type_convert.get(inner_type, inner_type)
        result.append((f"class:{style_class}.{inner_type}", template.format(inner_name)))
    return result


def compose_command_syntax(command_info, style_class="bottom-toolbar"):
    """
    Build the FormattedText fragments describing a command's arguments.

    Each entry of ``command_info["arguments"]`` is rendered as one of:

    - a ``command`` keyword, wrapped in brackets together with its ``enum``
      choices or its named argument(s), e.g. ``[EX seconds]``;
    - a bare ``enum``, rendered as ``[a|b|c]``;
    - otherwise a named argument, via ``parse_argument_to_formatted_text``.

    :param command_info: dict for one command, as loaded from commands.json.
    :param style_class: prefix used to build the ``class:`` style names.
    :return: list of ``(style, text)`` tuples; empty if there are no
        arguments.
    """
    keyword_style = f"class:{style_class}.command"
    choice_style = f"class:{style_class}.const"

    def render_choices(choices):
        # "[a|b|c]" listing of the allowed constant values
        return (choice_style, " [" + "|".join(choices) + "]")

    def render_named(spec):
        # named argument(s), styled by their type
        return parse_argument_to_formatted_text(
            spec["name"],
            spec["type"],
            spec.get("optional"),
            style_class=style_class,
        )

    fragments = []
    for spec in command_info.get("arguments") or ():
        keyword = spec.get("command")
        if not keyword:
            if spec.get("enum"):
                fragments.append(render_choices(spec["enum"]))
            else:
                fragments.extend(render_named(spec))
            continue

        fragments.append((keyword_style, " [" + keyword))
        if spec.get("enum"):
            fragments.append(render_choices(spec["enum"]))
        elif spec.get("name"):
            fragments.extend(render_named(spec))
        fragments.append((keyword_style, "]"))
    return fragments


def command_syntax(command, command_info, style_class="bottom-toolbar"):
    """
    Get command syntax based on redis-doc/commands.json.

    Builds the line ``(group) COMMAND args...``, followed by the ``since``
    and ``complexity`` fields when they are present.

    :param command: Command name in uppercase.
    :param command_info: dict loaded from commands.json, only for this
        command. Must contain ``group``.
    :param style_class: prefix for the ``class:`` style names. Defaults to
        ``"bottom-toolbar"``, the value that used to be hard-coded, and is
        passed through to ``compose_command_syntax``.
    :return: ``FormattedText`` ready for display.
    """
    command_group = command_info["group"]
    bottoms = [
        (f"class:{style_class}.group", f"({command_group}) "),
        (f"class:{style_class}.command", f"{command}"),
    ]  # final display FormattedText

    bottoms += compose_command_syntax(command_info, style_class=style_class)

    if "since" in command_info:
        since = command_info["since"]
        bottoms.append((f"class:{style_class}.since", f"   since: {since}"))
    if "complexity" in command_info:
        complexity = command_info["complexity"]
        bottoms.append(
            (f"class:{style_class}.complexity", f" complexity:{complexity}")
        )

    return FormattedText(bottoms)


def _literal_bytes(b):
    """
    convert bytes to printable text.

    backslash and double-quotes will be escaped by
    backslash.
    "hello\" -> \"hello\\\"

    we don't add outter double quotes here, since
    completer also need this function's return value
    to patch completers.

    b'hello' -> "hello"
    b'double"quotes"' -> "double\"quotes\""
    """
    s = str(b)
    s = s[2:-1]  # remove b' '
    # unescape single quote
    s = s.replace(r"\'", "'")
    return s


def ensure_str(origin, decode=None):
    """
    Ensure ``origin`` is a string, for display and completion.

    - ``None`` and ``str`` are returned unchanged.
    - ``int`` is converted with ``str()``.
    - ``bytes`` are decoded with the ``decode`` codec when one is given.
      Otherwise they are turned into escaped printable text by
      ``_literal_bytes``.
    - ``list`` is converted element by element, recursively. ``decode`` is
      passed down so nested bytes are treated like top-level ones.

    Outer double quotes are NOT added here; that is ``double_quotes``' job.

    Note: this does not render Redis nil; callers must check for nil
    themselves before calling this.

    :raises TypeError: for any other type (a subclass of ``Exception``, so
        existing ``except Exception`` handlers still apply).
    """
    if origin is None or isinstance(origin, str):
        return origin
    if isinstance(origin, int):
        return str(origin)
    if isinstance(origin, list):
        return [ensure_str(item, decode) for item in origin]
    if isinstance(origin, bytes):
        if decode:
            return origin.decode(decode)
        return _literal_bytes(origin)
    raise TypeError(f"Unkown type: {type(origin)}, origin: {origin}")


def double_quotes(unquoted):
    """
    Quote a value for display, the way redis-cli does.

    Inner double quotes are escaped with a backslash and the whole string is
    wrapped in double quotes. Lists are handled element by element
    (recursively). Any other type gives ``None``.

    :param unquoted: list, or str
    """
    if isinstance(unquoted, list):
        return [double_quotes(element) for element in unquoted]
    if not isinstance(unquoted, str):
        return None
    return '"' + unquoted.replace('"', '\\"') + '"'


def exit():
    """
    Leave the IRedis REPL.

    Prints a farewell message, then calls ``sys.exit()``, which raises
    ``SystemExit`` with no exit code.
    """
    farewell = "Goodbye!"
    print(farewell)
    sys.exit()


def convert_formatted_text_to_bytes(formatted_text):
    """
    Concatenate the text of FormattedText fragments and encode it as UTF-8.

    Styles are discarded. Fragments may be ``(style, text)`` or
    ``(style, text, mouse_handler)`` tuples, since prompt_toolkit allows both.
    Only the text at index 1 is used.

    :param formatted_text: iterable of prompt_toolkit style/text tuples.
    :return: the concatenated text as ``bytes``.
    """
    return "".join(fragment[1] for fragment in formatted_text).encode()


# Connection settings parsed from a URL by ``parse_url``.
DSN = namedtuple(
    "DSN", ["scheme", "host", "port", "path", "db", "username", "password"]
)


def parse_url(url, db=0):
    """
    Parse a Redis connection URL into a ``DSN`` namedtuple.

    For example::

        redis://[[username]:[password]]@localhost:6379/0
        rediss://[[username]:[password]]@localhost:6379/0
        unix://[[username]:[password]]@/path/to/socket.sock?db=0

    Three URL schemes are supported:

    - `redis:// <http://www.iana.org/assignments/uri-schemes/prov/redis>`_
      creates a normal TCP socket connection
    - `rediss:// <http://www.iana.org/assignments/uri-schemes/prov/rediss>`_
      creates a SSL wrapped TCP socket connection
    - ``unix://`` creates a Unix Domain Socket connection

    How the database number is chosen:

    - ``unix://``: the ``db`` querystring option, e.g.
      unix:///tmp/redis.sock?db=1. The URL path is the socket path.
    - ``redis://`` / ``rediss://``: the URL path, e.g. redis://localhost/0.
      If the path, with every ``/`` removed, parses as an integer, it becomes
      ``db`` and ``path`` is returned as ``None``. Otherwise the path is
      returned unchanged. The querystring is not read for these schemes.
    - If neither applies, the ``db`` argument to this function (default 0).

    :return: ``DSN(scheme, host, port, path, db, username, password)``.
        Host, username, password and path are percent-decoded. ``port`` is
        an ``int`` or ``None``.
    :raises ValueError: for an unsupported scheme, a non-integer ``db``
        query value, or an invalid port.
    """
    parsed = urlparse(url)

    scheme = parsed.scheme
    path = unquote(parsed.path) if parsed.path else None

    if scheme == "unix":
        # The URL path is the socket path, so ``db`` comes from the query.
        qs = parse_qs(parsed.query)
        if "db" in qs:
            db = int(qs["db"][0] or db)
    elif scheme in ("redis", "rediss"):
        # For TCP URLs a numeric path such as ``/0`` selects the database.
        if path:
            try:
                db = int(path.replace("/", ""))
                path = None
            except ValueError:
                pass
    else:
        valid_schemes = ", ".join(("redis://", "rediss://", "unix://"))
        raise ValueError(
            "Redis URL must specify one of the following "
            f"schemes ({valid_schemes})"
        )

    username = unquote(parsed.username) if parsed.username else None
    password = unquote(parsed.password) if parsed.password else None
    hostname = unquote(parsed.hostname) if parsed.hostname else None
    port = parsed.port

    return DSN(scheme, hostname, port, path, db, username, password)