Edit on GitHub

sqlglot.helper

  1from __future__ import annotations
  2
  3import inspect
  4import logging
  5import re
  6import sys
  7import typing as t
  8from collections.abc import Collection
  9from contextlib import contextmanager
 10from copy import copy
 11from enum import Enum
 12
 13if t.TYPE_CHECKING:
 14    from sqlglot import exp
 15    from sqlglot._typing import E, T
 16    from sqlglot.expressions import Expression
 17
 18CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
 19PYTHON_VERSION = sys.version_info[:2]
 20logger = logging.getLogger("sqlglot")
 21
 22
 23class AutoName(Enum):
 24    """This is used for creating enum classes where `auto()` is the string form of the corresponding value's name."""
 25
 26    def _generate_next_value_(name, _start, _count, _last_values):
 27        return name
 28
 29
 30def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
 31    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
 32    try:
 33        return seq[index]
 34    except IndexError:
 35        return None
 36
 37
 38@t.overload
 39def ensure_list(value: t.Collection[T]) -> t.List[T]:
 40    ...
 41
 42
 43@t.overload
 44def ensure_list(value: T) -> t.List[T]:
 45    ...
 46
 47
 48def ensure_list(value):
 49    """
 50    Ensures that a value is a list, otherwise casts or wraps it into one.
 51
 52    Args:
 53        value: the value of interest.
 54
 55    Returns:
 56        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
 57    """
 58    if value is None:
 59        return []
 60    if isinstance(value, (list, tuple)):
 61        return list(value)
 62
 63    return [value]
 64
 65
 66@t.overload
 67def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
 68    ...
 69
 70
 71@t.overload
 72def ensure_collection(value: T) -> t.Collection[T]:
 73    ...
 74
 75
 76def ensure_collection(value):
 77    """
 78    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 79
 80    Args:
 81        value: the value of interest.
 82
 83    Returns:
 84        The value if it's a collection, or else the value wrapped in a list.
 85    """
 86    if value is None:
 87        return []
 88    return (
 89        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
 90    )
 91
 92
 93def csv(*args: str, sep: str = ", ") -> str:
 94    """
 95    Formats any number of string arguments as CSV.
 96
 97    Args:
 98        args: the string arguments to format.
 99        sep: the argument separator.
100
101    Returns:
102        The arguments formatted as a CSV string.
103    """
104    return sep.join(arg for arg in args if arg)
105
106
def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Only classes that are attributes of the module `module_name` are considered. The
    classes in `classes` count as their own subclasses, so they are included unless excluded.

    Args:
        module_name: the name of the module to search for subclasses in (it must already be
            imported, i.e. present in `sys.modules`).
        classes: class(es) we want to find the subclasses of.
        exclude: class(es) we want to exclude from the returned list; a single class or a
            tuple of classes.

    Returns:
        The target subclasses, ordered by attribute name (as returned by `inspect.getmembers`).
    """
    # `exclude` may be a single class, as the annotation allows. Wrap it so the membership
    # test below never tries to iterate over a class object, which would raise TypeError.
    excluded = (exclude,) if isinstance(exclude, type) else tuple(exclude)

    return [
        obj
        for _, obj in inspect.getmembers(
            sys.modules[module_name],
            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in excluded,
        )
    ]
130
131
def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[t.Optional[E]],
    offset: int,
) -> t.List[t.Optional[E]]:
    """
    Shifts a single integer-typed index expression by `offset`.

    Args:
        this: the target of the index.
        expressions: the index expression(s), wrapped in a list.
        offset: the amount added to the index.

    Returns:
        A one-element list holding `index + offset` (simplified) when `offset` is non-zero,
        `expressions` holds exactly one expression, `this` is typed ARRAY or UNKNOWN and the
        index is integer-typed. In every other case `expressions` is returned unchanged.
    """
    # No shift requested, or not exactly one index argument: leave untouched.
    if not offset or len(expressions) != 1:
        return expressions

    # Imported inside the function rather than at module level; presumably to avoid an
    # import cycle with sqlglot.expressions (TODO confirm).
    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    index = expressions[0]

    if not this.type:
        annotate_types(this)

    target_kind = t.cast(exp.DataType, this.type).this
    if target_kind not in (exp.DataType.Type.UNKNOWN, exp.DataType.Type.ARRAY):
        return expressions

    if not index:
        return expressions

    if not index.type:
        annotate_types(index)

    if t.cast(exp.DataType, index.type).this not in exp.DataType.INTEGER_TYPES:
        return expressions

    logger.warning("Applying array index offset (%s)", offset)
    shifted = exp.Add(this=index.copy(), expression=exp.Literal.number(offset))
    return [simplify(shifted)]
178
179
def camel_to_snake_case(name: str) -> str:
    """
    Converts `name` from CamelCase to upper-case snake case and returns the result.

    An underscore is inserted before every uppercase letter except the first character,
    and the whole string is then upper-cased, e.g. `ArrayAgg` -> `ARRAY_AGG`.
    """
    return CAMEL_CASE_PATTERN.sub("_", name).upper()
183
184
def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Applies a transformation to a given expression until a fix point is reached.

    Before each call to `func`, every node's hash is computed and stored in its `_hash`
    attribute. After the call those stored values are reset to `None`, even if `func` raises.
    The loop stops once the root's hash is the same before and after `func`.

    Args:
        expression: the expression to be transformed.
        func: the transformation to be applied.

    Returns:
        The transformed expression.
    """
    while True:
        # Fill `_hash` in reverse walk order. Presumably this lets each node reuse values
        # already stored on nodes that `walk()` yields later (its descendants); TODO confirm.
        for node, *_ in reversed(tuple(expression.walk())):
            node._hash = hash(node)
        start = hash(expression)
        try:
            expression = func(expression)
        finally:
            # Always drop the stored hashes; otherwise a failing `func` would leave stale
            # values on the (possibly mutated) tree. On success this walks the result tree,
            # on failure the tree that was passed to `func`.
            for node, *_ in expression.walk():
                node._hash = None
        if start == hash(expression):
            break
    return expression
207
208
def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    Each key maps to the nodes it depends on. Dependencies always appear before their
    dependents in the result, and nodes that only appear as dependencies are included too.

    Args:
        dag: the graph to be sorted.

    Returns:
        A list that contains all of the graph's nodes in topological order.

    Raises:
        ValueError: if the graph contains a cycle.
    """
    result: t.List[T] = []
    # Set mirror of `result`, so the "already emitted" check is O(1) instead of a list scan.
    done: t.Set[T] = set()

    def visit(node: T, visiting: t.Set[T]) -> None:
        if node in done:
            return
        # Reaching a node that is still on the current DFS path means we looped back.
        if node in visiting:
            raise ValueError("Cycle error")

        visiting.add(node)
        for dep in dag.get(node, []):
            visit(dep, visiting)
        visiting.remove(node)

        done.add(node)
        result.append(node)

    for node in dag:
        visit(node, set())

    return result
239
240
def open_file(file_name: str) -> t.TextIO:
    """
    Opens a text file that may be gzip-compressed, decoding it as UTF-8.

    Gzip compression is detected from the two leading magic bytes, not the file extension.
    Both branches use `newline=""` (universal newlines mode with line endings returned
    untranslated), as recommended for `csv` readers.

    Args:
        file_name: path of the file to open.

    Returns:
        An open text stream; the caller is responsible for closing it.
    """
    with open(file_name, "rb") as f:
        # Gzip streams start with the magic bytes 0x1f 0x8b.
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        # Decode compressed files as UTF-8 too, instead of the locale's default encoding,
        # so both branches read the same text.
        return gzip.open(file_name, "rt", encoding="utf-8", newline="")

    return open(file_name, encoding="utf-8", newline="")
252
253
@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    The trailing arguments are read as consecutive key/value pairs. Only `delimiter` is
    honored, and it defaults to ",". The underlying file is closed when the context exits.

    Args:
        read_csv: a `ReadCSV` function call

    Yields:
        A python csv reader.
    """
    # Aliased because this module defines its own `csv` function.
    import csv as csv_

    delimiter = ","
    # zip() over one shared iterator pairs up (k0, v0), (k1, v1), ...; a trailing unpaired
    # argument is ignored.
    names = iter(arg.name for arg in read_csv.expressions)
    for key, value in zip(names, names):
        if key == "delimiter":
            delimiter = value

    # Open only after the arguments are parsed, and under `with`, so no failure path
    # can leak the file handle.
    with open_file(read_csv.name) as file:
        yield csv_.reader(file, delimiter=delimiter)
280
281
def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Pick a name not present in `taken`.

    `base` is returned as is when it is free. Otherwise the first free candidate among
    `base_2`, `base_3`, ... is returned.

    Args:
        taken: a collection of taken names.
        base: base name to alter.

    Returns:
        The new, available name.
    """
    candidate = base
    suffix = 2
    while candidate in taken:
        candidate = f"{base}_{suffix}"
        suffix += 1
    return candidate
303
304
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """
    Build a dict from `obj`'s instance attributes, each shallow-copied.

    Entries in `kwargs` are added last (not copied) and override same-named attributes.
    """
    attributes = dict(vars(obj))
    result = {name: copy(value) for name, value in attributes.items()}
    result.update(kwargs)
    return result
308
309
def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Split `value` on `sep` and pad the result with `None` up to `min_num_words` entries.

    Args:
        value: the value to be split.
        sep: the value to use to split on.
        min_num_words: the minimum number of words that are going to be in the result.
        fill_from_start: whether the `None` padding is inserted at the start (True) or at
            the end (False) of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The words produced by `str.split`, plus any `None` padding; never truncated.
    """
    parts = value.split(sep)
    missing = max(0, min_num_words - len(parts))
    padding: t.List[t.Optional[str]] = [None] * missing
    if fill_from_start:
        return padding + parts
    return parts + padding
337
338
def is_iterable(value: t.Any) -> bool:
    """
    Report whether `value` is iterable, treating `str` and `bytes` as non-iterable.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: the value to check if it is an iterable.

    Returns:
        True if `value` has an `__iter__` attribute and is not a `str`/`bytes` instance.
    """
    if isinstance(value, (str, bytes)):
        return False
    return hasattr(value, "__iter__")
356
357
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Recursively yield the leaf elements of an arbitrarily nested iterable.

    `str` and `bytes` objects are treated as leaves, not as iterables.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: the value to be flattened.

    Yields:
        Non-iterable elements in `values`, in iteration order.
    """
    for item in values:
        if not is_iterable(item):
            yield item
            continue
        yield from flatten(item)
380
381
def dict_depth(d: t.Dict) -> int:
    """
    Get the nesting depth of a dictionary, following only the first value at each level.

    For example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3

    Args:
        d (dict): dictionary

    Returns:
        int: depth (0 for anything without a `values` attribute)
    """
    try:
        first_value = next(iter(d.values()))
    except AttributeError:
        # Not dict-like: contributes no depth.
        return 0
    except StopIteration:
        # An empty dict is one level deep.
        return 1
    return 1 + dict_depth(first_value)
412
413
def first(it: t.Iterable[T]) -> T:
    """Return the first element produced by iterating `it` (handy for sets).

    Raises StopIteration if `it` is empty.
    """
    return next(iter(it))
420
421
def should_identify(text: str, identify: str | bool) -> bool:
    """Decide whether `text` should be identified under the given `identify` option.

    Args:
        text: the text to check.
        identify: `True` or "always" - always identify; "safe" - identify only when `text`
            has no uppercase characters; anything else - never identify.

    Returns:
        Whether or not a string should be identified.
    """
    # `is True` on purpose: the integer 1 must not count as True here.
    if identify is True or identify == "always":
        return True
    if identify != "safe":
        return False
    return all(not ch.isupper() for ch in text)
class AutoName(Enum):
    """
    Enum base class whose `auto()` values are the member names themselves.

    For example, a member declared as `FOO = auto()` gets the value `"FOO"`.
    """

    def _generate_next_value_(member_name, _start, _count, _last_values):
        # Invoked by the enum machinery for each `auto()`: reuse the member's own name.
        return member_name

This is used for creating enum classes where auto() is the string form of the corresponding value's name.

Inherited Members
enum.Enum
name
value
def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    """Return `seq[index]`, or `None` when `index` falls outside the sequence."""
    try:
        item = seq[index]
    except IndexError:
        item = None
    return item

Returns the value in seq at position index, or None if index is out of bounds.

def ensure_list(value):
    """
    Coerce `value` into a list.

    Args:
        value: the value of interest.

    Returns:
        An empty list for `None`, a new list holding the same items for a list or tuple,
        or otherwise a one-element list containing `value` itself (so sets, dicts and
        strings are wrapped, not expanded).
    """
    if value is None:
        wrapped = []
    elif isinstance(value, (list, tuple)):
        wrapped = list(value)
    else:
        wrapped = [value]
    return wrapped

Ensures that a value is a list, otherwise casts or wraps it into one.

Arguments:
  • value: the value of interest.
Returns:

The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.

def ensure_collection(value):
    """
    Make sure `value` can be treated as a collection.

    `str` and `bytes` count as scalars here even though they are technically collections.

    Args:
        value: the value of interest.

    Returns:
        An empty list for `None`, `value` itself (not a copy) if it is a non-text
        collection, or otherwise `value` wrapped in a one-element list.
    """
    if value is None:
        return []
    if isinstance(value, Collection) and not isinstance(value, (str, bytes)):
        return value
    return [value]

Ensures that a value is a collection (excluding str and bytes), otherwise wraps it into a list.

Arguments:
  • value: the value of interest.
Returns:

The value if it's a collection, or else the value wrapped in a list.

def csv(*args: str, sep: str = ", ") -> str:
    """
    Join any number of string arguments with `sep`, dropping empty (falsy) ones.

    No quoting or escaping is performed.

    Args:
        args: the string arguments to format.
        sep: the argument separator.

    Returns:
        The non-empty arguments joined by `sep`.
    """
    return sep.join(filter(None, args))

Formats any number of string arguments as CSV.

Arguments:
  • args: the string arguments to format.
  • sep: the argument separator.
Returns:

The arguments formatted as a CSV string.

def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Only classes that are attributes of the module `module_name` are considered. The
    classes in `classes` count as their own subclasses, so they are included unless excluded.

    Args:
        module_name: the name of the module to search for subclasses in (it must already be
            imported, i.e. present in `sys.modules`).
        classes: class(es) we want to find the subclasses of.
        exclude: class(es) we want to exclude from the returned list; a single class or a
            tuple of classes.

    Returns:
        The target subclasses, ordered by attribute name (as returned by `inspect.getmembers`).
    """
    # `exclude` may be a single class, as the annotation allows. Wrap it so the membership
    # test below never tries to iterate over a class object, which would raise TypeError.
    excluded = (exclude,) if isinstance(exclude, type) else tuple(exclude)

    return [
        obj
        for _, obj in inspect.getmembers(
            sys.modules[module_name],
            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in excluded,
        )
    ]

Returns all subclasses for a collection of classes, possibly excluding some of them.

Arguments:
  • module_name: the name of the module to search for subclasses in.
  • classes: class(es) we want to find the subclasses of.
  • exclude: class(es) we want to exclude from the returned list.
Returns:

The target subclasses.

def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[t.Optional[E]],
    offset: int,
) -> t.List[t.Optional[E]]:
    """
    Shifts a single integer-typed index expression by `offset`.

    Args:
        this: the target of the index.
        expressions: the index expression(s), wrapped in a list.
        offset: the amount added to the index.

    Returns:
        A one-element list holding `index + offset` (simplified) when `offset` is non-zero,
        `expressions` holds exactly one expression, `this` is typed ARRAY or UNKNOWN and the
        index is integer-typed. In every other case `expressions` is returned unchanged.
    """
    # No shift requested, or not exactly one index argument: leave untouched.
    if not offset or len(expressions) != 1:
        return expressions

    # Imported inside the function rather than at module level; presumably to avoid an
    # import cycle with sqlglot.expressions (TODO confirm).
    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    index = expressions[0]

    if not this.type:
        annotate_types(this)

    target_kind = t.cast(exp.DataType, this.type).this
    if target_kind not in (exp.DataType.Type.UNKNOWN, exp.DataType.Type.ARRAY):
        return expressions

    if not index:
        return expressions

    if not index.type:
        annotate_types(index)

    if t.cast(exp.DataType, index.type).this not in exp.DataType.INTEGER_TYPES:
        return expressions

    logger.warning("Applying array index offset (%s)", offset)
    shifted = exp.Add(this=index.copy(), expression=exp.Literal.number(offset))
    return [simplify(shifted)]

Applies an offset to a given integer-typed index expression.

Arguments:
  • this: the target of the index
  • expressions: the expression the offset will be applied to, wrapped in a list.
  • offset: the offset that will be applied.
Returns:

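The expression with the offset applied to it, wrapped in a list. If the provided expressions argument does not contain exactly one expression, or the offset is zero or cannot be applied (the target is not an array, or the index is not integer-typed), it's returned unaffected.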

def camel_to_snake_case(name: str) -> str:
    """
    Converts `name` from CamelCase to upper-case snake case and returns the result.

    An underscore is inserted before every uppercase letter except the first character,
    and the whole string is then upper-cased, e.g. `ArrayAgg` -> `ARRAY_AGG`.
    """
    return CAMEL_CASE_PATTERN.sub("_", name).upper()

Converts name from CamelCase to upper-case snake case (e.g. ArrayAgg becomes ARRAY_AGG) and returns the result.

def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Applies a transformation to a given expression until a fix point is reached.

    Before each call to `func`, every node's hash is computed and stored in its `_hash`
    attribute. After the call those stored values are reset to `None`, even if `func` raises.
    The loop stops once the root's hash is the same before and after `func`.

    Args:
        expression: the expression to be transformed.
        func: the transformation to be applied.

    Returns:
        The transformed expression.
    """
    while True:
        # Fill `_hash` in reverse walk order. Presumably this lets each node reuse values
        # already stored on nodes that `walk()` yields later (its descendants); TODO confirm.
        for node, *_ in reversed(tuple(expression.walk())):
            node._hash = hash(node)
        start = hash(expression)
        try:
            expression = func(expression)
        finally:
            # Always drop the stored hashes; otherwise a failing `func` would leave stale
            # values on the (possibly mutated) tree. On success this walks the result tree,
            # on failure the tree that was passed to `func`.
            for node, *_ in expression.walk():
                node._hash = None
        if start == hash(expression):
            break
    return expression

Applies a transformation to a given expression until a fix point is reached.

Arguments:
  • expression: the expression to be transformed.
  • func: the transformation to be applied.
Returns:

The transformed expression.

def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    Each key maps to the nodes it depends on. Dependencies always appear before their
    dependents in the result, and nodes that only appear as dependencies are included too.

    Args:
        dag: the graph to be sorted.

    Returns:
        A list that contains all of the graph's nodes in topological order.

    Raises:
        ValueError: if the graph contains a cycle.
    """
    result: t.List[T] = []
    # Set mirror of `result`, so the "already emitted" check is O(1) instead of a list scan.
    done: t.Set[T] = set()

    def visit(node: T, visiting: t.Set[T]) -> None:
        if node in done:
            return
        # Reaching a node that is still on the current DFS path means we looped back.
        if node in visiting:
            raise ValueError("Cycle error")

        visiting.add(node)
        for dep in dag.get(node, []):
            visit(dep, visiting)
        visiting.remove(node)

        done.add(node)
        result.append(node)

    for node in dag:
        visit(node, set())

    return result

Sorts a given directed acyclic graph in topological order.

Arguments:
  • dag: the graph to be sorted.
Returns:

A list that contains all of the graph's nodes in topological order.

def open_file(file_name: str) -> t.TextIO:
    """
    Opens a text file that may be gzip-compressed, decoding it as UTF-8.

    Gzip compression is detected from the two leading magic bytes, not the file extension.
    Both branches use `newline=""` (universal newlines mode with line endings returned
    untranslated), as recommended for `csv` readers.

    Args:
        file_name: path of the file to open.

    Returns:
        An open text stream; the caller is responsible for closing it.
    """
    with open(file_name, "rb") as f:
        # Gzip streams start with the magic bytes 0x1f 0x8b.
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        # Decode compressed files as UTF-8 too, instead of the locale's default encoding,
        # so both branches read the same text.
        return gzip.open(file_name, "rt", encoding="utf-8", newline="")

    return open(file_name, encoding="utf-8", newline="")

Open a file that may be compressed as gzip and return it in universal newline mode.

@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    The trailing arguments are read as consecutive key/value pairs. Only `delimiter` is
    honored, and it defaults to ",". The underlying file is closed when the context exits.

    Args:
        read_csv: a `ReadCSV` function call

    Yields:
        A python csv reader.
    """
    # Aliased because this module defines its own `csv` function.
    import csv as csv_

    delimiter = ","
    # zip() over one shared iterator pairs up (k0, v0), (k1, v1), ...; a trailing unpaired
    # argument is ignored.
    names = iter(arg.name for arg in read_csv.expressions)
    for key, value in zip(names, names):
        if key == "delimiter":
            delimiter = value

    # Open only after the arguments are parsed, and under `with`, so no failure path
    # can leak the file handle.
    with open_file(read_csv.name) as file:
        yield csv_.reader(file, delimiter=delimiter)

Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]).

Arguments:
  • read_csv: a ReadCSV function call
Yields:

A python csv reader.

def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Pick a name not present in `taken`.

    `base` is returned as is when it is free. Otherwise the first free candidate among
    `base_2`, `base_3`, ... is returned.

    Args:
        taken: a collection of taken names.
        base: base name to alter.

    Returns:
        The new, available name.
    """
    candidate = base
    suffix = 2
    while candidate in taken:
        candidate = f"{base}_{suffix}"
        suffix += 1
    return candidate

Searches for a new name.

Arguments:
  • taken: a collection of taken names.
  • base: base name to alter.
Returns:

The new, available name.

def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """
    Build a dict from `obj`'s instance attributes, each shallow-copied.

    Entries in `kwargs` are added last (not copied) and override same-named attributes.
    """
    attributes = dict(vars(obj))
    result = {name: copy(value) for name, value in attributes.items()}
    result.update(kwargs)
    return result

Returns a dictionary created from an object's attributes.

def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Split `value` on `sep` and pad the result with `None` up to `min_num_words` entries.

    Args:
        value: the value to be split.
        sep: the value to use to split on.
        min_num_words: the minimum number of words that are going to be in the result.
        fill_from_start: whether the `None` padding is inserted at the start (True) or at
            the end (False) of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The words produced by `str.split`, plus any `None` padding; never truncated.
    """
    parts = value.split(sep)
    missing = max(0, min_num_words - len(parts))
    padding: t.List[t.Optional[str]] = [None] * missing
    if fill_from_start:
        return padding + parts
    return parts + padding

Perform a split on a value and return N words as a result with None used for words that don't exist.

Arguments:
  • value: the value to be split.
  • sep: the value to use to split on.
  • min_num_words: the minimum number of words that are going to be in the result.
  • fill_from_start: whether None values should be inserted at the start (True) or at the end (False) of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
[None, 'db', 'table']
>>> split_num_words("db.table", ".", 3, fill_from_start=False)
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
Returns:

The list of words returned by split, possibly augmented by a number of None values.

def is_iterable(value: t.Any) -> bool:
    """
    Report whether `value` is iterable, treating `str` and `bytes` as non-iterable.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: the value to check if it is an iterable.

    Returns:
        True if `value` has an `__iter__` attribute and is not a `str`/`bytes` instance.
    """
    if isinstance(value, (str, bytes)):
        return False
    return hasattr(value, "__iter__")

Checks if the value is an iterable, excluding the types str and bytes.

Examples:
>>> is_iterable([1,2])
True
>>> is_iterable("test")
False
Arguments:
  • value: the value to check if it is an iterable.
Returns:

A bool value indicating if it is an iterable.

def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Recursively yield the leaf elements of an arbitrarily nested iterable.

    `str` and `bytes` objects are treated as leaves, not as iterables.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: the value to be flattened.

    Yields:
        Non-iterable elements in `values`, in iteration order.
    """
    for item in values:
        if not is_iterable(item):
            yield item
            continue
        yield from flatten(item)

Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type str and bytes are not regarded as iterables.

Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
[1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Arguments:
  • values: the value to be flattened.
Yields:

Non-iterable elements in values.

def dict_depth(d: t.Dict) -> int:
    """
    Get the nesting depth of a dictionary, following only the first value at each level.

    For example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3

    Args:
        d (dict): dictionary

    Returns:
        int: depth (0 for anything without a `values` attribute)
    """
    try:
        first_value = next(iter(d.values()))
    except AttributeError:
        # Not dict-like: contributes no depth.
        return 0
    except StopIteration:
        # An empty dict is one level deep.
        return 1
    return 1 + dict_depth(first_value)

Get the nesting depth of a dictionary.

For example:
>>> dict_depth(None)
0
>>> dict_depth({})
1
>>> dict_depth({"a": "b"})
1
>>> dict_depth({"a": {}})
2
>>> dict_depth({"a": {"b": {}}})
3
Arguments:
  • d (dict): dictionary
Returns:

int: depth

def first(it: t.Iterable[T]) -> T:
    """Return the first element produced by iterating `it` (handy for sets).

    Raises StopIteration if `it` is empty.
    """
    return next(iter(it))

Returns the first element from an iterable.

Useful for sets.

def should_identify(text: str, identify: str | bool) -> bool:
    """Decide whether `text` should be identified under the given `identify` option.

    Args:
        text: the text to check.
        identify: `True` or "always" - always identify; "safe" - identify only when `text`
            has no uppercase characters; anything else - never identify.

    Returns:
        Whether or not a string should be identified.
    """
    # `is True` on purpose: the integer 1 must not count as True here.
    if identify is True or identify == "always":
        return True
    if identify != "safe":
        return False
    return all(not ch.isupper() for ch in text)

Checks if text should be identified given an identify option.

Arguments:
  • text: the text to check.
  • identify: "always" | True - always returns true, "safe" - true if no upper case
Returns:

Whether or not a string should be identified.