Edit on GitHub

sqlglot.helper

  1from __future__ import annotations
  2
  3import inspect
  4import logging
  5import re
  6import sys
  7import typing as t
  8from collections.abc import Collection
  9from contextlib import contextmanager
 10from copy import copy
 11from enum import Enum
 12from itertools import count
 13
 14if t.TYPE_CHECKING:
 15    from sqlglot import exp
 16    from sqlglot._typing import E, T
 17    from sqlglot.expressions import Expression
 18
 19CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
 20PYTHON_VERSION = sys.version_info[:2]
 21logger = logging.getLogger("sqlglot")
 22
 23
 24class AutoName(Enum):
 25    """
 26    This is used for creating Enum classes where `auto()` is the string form
 27    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
 28
 29    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
 30    """
 31
 32    def _generate_next_value_(name, _start, _count, _last_values):
 33        return name
 34
 35
 36def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
 37    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
 38    try:
 39        return seq[index]
 40    except IndexError:
 41        return None
 42
 43
 44@t.overload
 45def ensure_list(value: t.Collection[T]) -> t.List[T]:
 46    ...
 47
 48
 49@t.overload
 50def ensure_list(value: T) -> t.List[T]:
 51    ...
 52
 53
 54def ensure_list(value):
 55    """
 56    Ensures that a value is a list, otherwise casts or wraps it into one.
 57
 58    Args:
 59        value: The value of interest.
 60
 61    Returns:
 62        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
 63    """
 64    if value is None:
 65        return []
 66    if isinstance(value, (list, tuple)):
 67        return list(value)
 68
 69    return [value]
 70
 71
 72@t.overload
 73def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
 74    ...
 75
 76
 77@t.overload
 78def ensure_collection(value: T) -> t.Collection[T]:
 79    ...
 80
 81
 82def ensure_collection(value):
 83    """
 84    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 85
 86    Args:
 87        value: The value of interest.
 88
 89    Returns:
 90        The value if it's a collection, or else the value wrapped in a list.
 91    """
 92    if value is None:
 93        return []
 94    return (
 95        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
 96    )
 97
 98
 99def csv(*args: str, sep: str = ", ") -> str:
100    """
101    Formats any number of string arguments as CSV.
102
103    Args:
104        args: The string arguments to format.
105        sep: The argument separator.
106
107    Returns:
108        The arguments formatted as a CSV string.
109    """
110    return sep.join(arg for arg in args if arg)
111
112
def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Args:
        module_name: The name of the (already imported) module to search for subclasses in.
        classes: Class(es) we want to find the subclasses of.
        exclude: Class(es) we want to exclude from the returned list; a single class or a tuple.

    Returns:
        The target subclasses, ordered by attribute name as returned by `inspect.getmembers`.
    """
    # `exclude` may be a single class. A membership test against a plain class
    # (`obj not in SomeClass`) raises TypeError, so normalize it to a tuple first.
    excluded = (exclude,) if inspect.isclass(exclude) else tuple(exclude)

    def _is_target(obj: t.Any) -> bool:
        return inspect.isclass(obj) and issubclass(obj, classes) and obj not in excluded

    return [obj for _, obj in inspect.getmembers(sys.modules[module_name], _is_target)]
136
137
def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[t.Optional[E]],
    offset: int,
) -> t.List[t.Optional[E]]:
    """
    Shift a single integer-typed index expression by `offset`.

    Args:
        this: The target of the index.
        expressions: The index expression(s), wrapped in a list.
        offset: The offset that will be added to the index.

    Returns:
        A one-element list holding the simplified `index + offset` expression when the offset
        applies; otherwise `expressions` itself, unchanged. It is returned unchanged when
        `offset` is zero, when there isn't exactly one expression, when `this` is typed as
        something other than ARRAY or UNKNOWN, or when the index is missing or not typed as
        an integer.
    """
    if not offset or len(expressions) != 1:
        return expressions

    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    index = expressions[0]

    if not this.type:
        annotate_types(this)

    target_type = t.cast(exp.DataType, this.type).this
    if target_type not in (exp.DataType.Type.UNKNOWN, exp.DataType.Type.ARRAY):
        return expressions

    if not index:
        return expressions

    if not index.type:
        annotate_types(index)

    if t.cast(exp.DataType, index.type).this not in exp.DataType.INTEGER_TYPES:
        return expressions

    logger.warning("Applying array index offset (%s)", offset)
    shifted = exp.Add(this=index.copy(), expression=exp.Literal.number(offset))
    return [simplify(shifted)]
184
185
def camel_to_snake_case(name: str) -> str:
    """
    Convert a camelCase / PascalCase `name` to upper-case snake case.

    An underscore is inserted before every upper-case letter except a leading one, and the
    result is upper-cased: "camelCase" -> "CAMEL_CASE", "ArrayAgg" -> "ARRAY_AGG".
    Consecutive capitals are split individually ("HTTPServer" -> "H_T_T_P_SERVER").
    """
    return CAMEL_CASE_PATTERN.sub("_", name).upper()
189
190
def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Repeatedly apply `func` to `expression` until its hash stops changing (a fix point).

    Before each application, every node's hash is cached in `_hash` (children first, by
    walking the tree in reverse) so hashing the root is cheap. Afterwards, the caches on the
    resulting tree are cleared, so its hash is recomputed from its current contents.

    Args:
        expression: The expression to be transformed.
        func: The transformation to be applied.

    Returns:
        The transformed expression.
    """
    while True:
        for node, *_ in reversed(tuple(expression.walk())):
            node._hash = hash(node)

        before = hash(expression)
        expression = func(expression)

        # Drop cached hashes so the comparison below sees the new tree's real hash.
        for node, *_ in expression.walk():
            node._hash = None

        if hash(expression) == before:
            return expression
215
216
def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    The graph maps each node to the set of nodes it depends on; dependencies come before
    their dependents in the result. Nodes that only appear as dependencies are included too.
    Nodes that become ready in the same round are emitted in sorted order, so they must be
    mutually comparable.

    Args:
        dag: The graph to be sorted. It is not modified.

    Returns:
        A list that contains all of the graph's nodes in topological order.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    # Work on a private copy: the loop below removes nodes and shrinks dependency sets,
    # which would otherwise empty the caller's dict and mutate the caller's sets.
    graph = {node: set(deps) for node, deps in dag.items()}
    for deps in dag.values():
        for dep in deps:
            graph.setdefault(dep, set())

    result: t.List[T] = []
    while graph:
        ready = {node for node, deps in graph.items() if not deps}
        if not ready:
            raise ValueError("Cycle error")

        for node in ready:
            del graph[node]
        for deps in graph.values():
            deps -= ready

        result.extend(sorted(ready))  # type: ignore

    return result
249
250
def open_file(file_name: str) -> t.TextIO:
    """
    Open a text file that may be gzip-compressed.

    The first two bytes are checked for the gzip magic number (0x1f 0x8b). In both cases
    the content is decoded as UTF-8 and opened with `newline=""`, so line endings are
    returned untranslated, as the `csv` module expects.

    Args:
        file_name: Path of the file to open.

    Returns:
        A text-mode file object; the caller is responsible for closing it.
    """
    with open(file_name, "rb") as f:
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        # Decode compressed files as UTF-8 too, rather than with the locale's default encoding.
        return gzip.open(file_name, "rt", encoding="utf-8", newline="")

    return open(file_name, encoding="utf-8", newline="")
262
263
@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    The options after the file name are read as flat key/value pairs. Only `delimiter` is
    recognized (default ","); other keys are ignored. The underlying file is closed when the
    context exits, including when an exception is raised.

    Args:
        read_csv: A `ReadCSV` function call.

    Yields:
        A python csv reader.
    """
    import csv as csv_  # aliased because this module defines its own `csv` function

    delimiter = ","
    # Zipping a single iterator with itself consumes the option names two at a time: (key, value).
    names = iter(arg.name for arg in read_csv.expressions)
    for key, value in zip(names, names):
        if key == "delimiter":
            delimiter = value

    # Open only after the options are parsed, and under `with`, so the file cannot leak.
    with open_file(read_csv.name) as file:
        yield csv_.reader(file, delimiter=delimiter)
290
291
def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Return `base` if it is free, otherwise the first free `"{base}_{i}"` for i = 2, 3, ...

    Args:
        taken: A collection of taken names.
        base: Base name to alter.

    Returns:
        The new, available name.
    """
    if base not in taken:
        return base

    candidates = (f"{base}_{i}" for i in count(2))
    return next(name for name in candidates if name not in taken)
313
314
def name_sequence(prefix: str) -> t.Callable[[], str]:
    """
    Build a callable that returns `prefix` followed by 0, 1, 2, ... on successive calls.

    E.g. with prefix "a" it returns "a0", then "a1", then "a2", ... Each callable returned
    by this function has its own independent counter.
    """
    counter = count()

    def _next_name() -> str:
        return f"{prefix}{next(counter)}"

    return _next_name
319
320
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """
    Build a dict from `obj`'s attributes (`vars(obj)`), copying each value.

    Values that have a `copy` attribute are copied with `value.copy()`; all others with
    `copy.copy`. Keyword arguments are applied last and override copied attributes.
    """
    result = {}
    for key, value in vars(obj).items():
        result[key] = value.copy() if hasattr(value, "copy") else copy(value)
    result.update(kwargs)
    return result
327
328
def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Split `value` on `sep`, padding the result with `None` up to `min_num_words` entries.

    Args:
        value: The value to be split.
        sep: The separator to split on.
        min_num_words: The minimum number of entries in the result; if the split already
            produces at least this many words, no padding is added.
        fill_from_start: Whether the `None` padding goes at the start (True) or at the end
            (False) of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The list of words returned by `split`, possibly augmented by a number of `None` values.
    """
    words = value.split(sep)
    padding = [None] * max(0, min_num_words - len(words))
    return padding + words if fill_from_start else words + padding
356
357
def is_iterable(value: t.Any) -> bool:
    """
    Tell whether `value` is iterable, treating `str` and `bytes` as non-iterable.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check.

    Returns:
        True if `value` defines `__iter__` and is neither `str` nor `bytes`, else False.
    """
    if isinstance(value, (str, bytes)):
        return False
    return hasattr(value, "__iter__")
375
376
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Lazily yield the leaf elements of an arbitrarily nested iterable.

    Anything `is_iterable` accepts is recursed into; everything else, including `str` and
    `bytes` objects, is yielded as-is.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The value to be flattened.

    Yields:
        Non-iterable elements in `values`, in iteration order.
    """
    for item in values:
        if not is_iterable(item):
            yield item
            continue
        yield from flatten(item)
399
400
def dict_depth(d: t.Dict) -> int:
    """
    Get the nesting depth of a dictionary.

    Only the first value at each level is followed, so a dict whose first value is not a
    dict counts as depth 1 even if later values are nested. Objects without a `values`
    method (e.g. `None`) have depth 0.

    Example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3
    """
    try:
        first_value = next(iter(d.values()))
    except AttributeError:
        # `d` has no `values` method, so it is not dict-like.
        return 0
    except StopIteration:
        # An empty dict still counts as one level.
        return 1
    return 1 + dict_depth(first_value)
425
426
def first(it: t.Iterable[T]) -> T:
    """Return the first item produced by iterating `it` (handy for sets); raises StopIteration if empty."""
    return next(iter(it))
CAMEL_CASE_PATTERN = re.compile('(?<!^)(?=[A-Z])')
PYTHON_VERSION = (3, 10)
logger = <Logger sqlglot (WARNING)>
class AutoName(enum.Enum):
class AutoName(Enum):
    """
    Enum base class whose `auto()` values are the member names themselves.

    Subclass it and declare members with `auto()`; each member's value becomes the string
    form of its identifier, e.g. `FOO = auto()` gives `FOO.value == "FOO"`.

    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
    """

    def _generate_next_value_(name, _start, _count, _last_values):
        # Enum invokes this hook for every `auto()`; the member's name is used as its value.
        return name

This is used for creating Enum classes where auto() produces the string form of the corresponding member's identifier (e.g. FOO.value results in "FOO").

Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values

Inherited Members
enum.Enum
name
value
def seq_get(seq: Sequence[~T], index: int) -> Optional[~T]:
def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    """
    Look up `seq[index]`, returning `None` instead of raising when the index is out of range.

    The lookup is a plain `seq[index]`, so for built-in sequences negative indices count
    from the end.
    """
    try:
        item = seq[index]
    except IndexError:
        item = None
    return item

Returns the value in seq at position index, or None if index is out of bounds.

def ensure_list(value):
def ensure_list(value):
    """
    Coerce `value` into a list.

    Args:
        value: The value of interest.

    Returns:
        `[]` for `None`; a new list holding the same items for a list or tuple; otherwise a
        one-element list `[value]` (this includes sets and strings, which are not unpacked).
    """
    if isinstance(value, (list, tuple)):
        return list(value)
    return [] if value is None else [value]

Ensures that a value is a list, otherwise casts or wraps it into one.

Arguments:
  • value: The value of interest.
Returns:

The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.

def ensure_collection(value):
def ensure_collection(value):
    """
    Return `value` itself if it is a collection other than `str`/`bytes`; otherwise wrap it.

    Args:
        value: The value of interest.

    Returns:
        `[]` for `None`; `value` unchanged (not copied) for a non-string collection;
        otherwise the one-element list `[value]`.
    """
    if value is None:
        return []
    is_text = isinstance(value, (str, bytes))
    if is_text or not isinstance(value, Collection):
        return [value]
    return value

Ensures that a value is a collection (excluding str and bytes), otherwise wraps it into a list.

Arguments:
  • value: The value of interest.
Returns:

The value if it's a collection, or else the value wrapped in a list.

def csv(*args: str, sep: str = ', ') -> str:
def csv(*args: str, sep: str = ", ") -> str:
    """
    Join string arguments with a separator, skipping empty ones.

    Args:
        args: The strings to join; falsy entries (e.g. "") are dropped.
        sep: The separator placed between consecutive kept arguments.

    Returns:
        The joined string (empty if no argument is truthy).
    """
    return sep.join(filter(None, args))

Formats any number of string arguments as CSV.

Arguments:
  • args: The string arguments to format.
  • sep: The argument separator.
Returns:

The arguments formatted as a CSV string.

def subclasses( module_name: str, classes: Union[Type, Tuple[Type, ...]], exclude: Union[Type, Tuple[Type, ...]] = ()) -> List[Type]:
def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Args:
        module_name: The name of the (already imported) module to search for subclasses in.
        classes: Class(es) we want to find the subclasses of.
        exclude: Class(es) we want to exclude from the returned list; a single class or a tuple.

    Returns:
        The target subclasses, ordered by attribute name as returned by `inspect.getmembers`.
    """
    # `exclude` may be a single class. A membership test against a plain class
    # (`obj not in SomeClass`) raises TypeError, so normalize it to a tuple first.
    excluded = (exclude,) if inspect.isclass(exclude) else tuple(exclude)

    def _is_target(obj: t.Any) -> bool:
        return inspect.isclass(obj) and issubclass(obj, classes) and obj not in excluded

    return [obj for _, obj in inspect.getmembers(sys.modules[module_name], _is_target)]

Returns all subclasses for a collection of classes, possibly excluding some of them.

Arguments:
  • module_name: The name of the module to search for subclasses in.
  • classes: Class(es) we want to find the subclasses of.
  • exclude: Class(es) we want to exclude from the returned list.
Returns:

The target subclasses.

def apply_index_offset( this: sqlglot.expressions.Expression, expressions: List[Optional[~E]], offset: int) -> List[Optional[~E]]:
def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[t.Optional[E]],
    offset: int,
) -> t.List[t.Optional[E]]:
    """
    Shift a single integer-typed index expression by `offset`.

    Args:
        this: The target of the index.
        expressions: The index expression(s), wrapped in a list.
        offset: The offset that will be added to the index.

    Returns:
        A one-element list holding the simplified `index + offset` expression when the offset
        applies; otherwise `expressions` itself, unchanged. It is returned unchanged when
        `offset` is zero, when there isn't exactly one expression, when `this` is typed as
        something other than ARRAY or UNKNOWN, or when the index is missing or not typed as
        an integer.
    """
    if not offset or len(expressions) != 1:
        return expressions

    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    index = expressions[0]

    if not this.type:
        annotate_types(this)

    target_type = t.cast(exp.DataType, this.type).this
    if target_type not in (exp.DataType.Type.UNKNOWN, exp.DataType.Type.ARRAY):
        return expressions

    if not index:
        return expressions

    if not index.type:
        annotate_types(index)

    if t.cast(exp.DataType, index.type).this not in exp.DataType.INTEGER_TYPES:
        return expressions

    logger.warning("Applying array index offset (%s)", offset)
    shifted = exp.Add(this=index.copy(), expression=exp.Literal.number(offset))
    return [simplify(shifted)]

Applies an offset to a given integer-typed index expression.

Arguments:
  • this: The target of the index.
  • expressions: The expression the offset will be applied to, wrapped in a list.
  • offset: The offset that will be applied.
Returns:

The original expression with the offset applied to it, wrapped in a list. If the provided expressions argument contains more than one expression, it's returned unaffected.

def camel_to_snake_case(name: str) -> str:
def camel_to_snake_case(name: str) -> str:
    """
    Convert a camelCase / PascalCase `name` to upper-case snake case.

    An underscore is inserted before every upper-case letter except a leading one, and the
    result is upper-cased: "camelCase" -> "CAMEL_CASE", "ArrayAgg" -> "ARRAY_AGG".
    Consecutive capitals are split individually ("HTTPServer" -> "H_T_T_P_SERVER").
    """
    return CAMEL_CASE_PATTERN.sub("_", name).upper()

Converts name from camelCase (or PascalCase) to upper-case snake case and returns the result, e.g. "camelCase" becomes "CAMEL_CASE".

def while_changing( expression: sqlglot.expressions.Expression, func: Callable[[sqlglot.expressions.Expression], ~E]) -> ~E:
def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Repeatedly apply `func` to `expression` until its hash stops changing (a fix point).

    Before each application, every node's hash is cached in `_hash` (children first, by
    walking the tree in reverse) so hashing the root is cheap. Afterwards, the caches on the
    resulting tree are cleared, so its hash is recomputed from its current contents.

    Args:
        expression: The expression to be transformed.
        func: The transformation to be applied.

    Returns:
        The transformed expression.
    """
    while True:
        for node, *_ in reversed(tuple(expression.walk())):
            node._hash = hash(node)

        before = hash(expression)
        expression = func(expression)

        # Drop cached hashes so the comparison below sees the new tree's real hash.
        for node, *_ in expression.walk():
            node._hash = None

        if hash(expression) == before:
            return expression

Applies a transformation to a given expression until a fix point is reached.

Arguments:
  • expression: The expression to be transformed.
  • func: The transformation to be applied.
Returns:

The transformed expression.

def tsort(dag: Dict[~T, Set[~T]]) -> List[~T]:
def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    The graph maps each node to the set of nodes it depends on; dependencies come before
    their dependents in the result. Nodes that only appear as dependencies are included too.
    Nodes that become ready in the same round are emitted in sorted order, so they must be
    mutually comparable.

    Args:
        dag: The graph to be sorted. It is not modified.

    Returns:
        A list that contains all of the graph's nodes in topological order.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    # Work on a private copy: the loop below removes nodes and shrinks dependency sets,
    # which would otherwise empty the caller's dict and mutate the caller's sets.
    graph = {node: set(deps) for node, deps in dag.items()}
    for deps in dag.values():
        for dep in deps:
            graph.setdefault(dep, set())

    result: t.List[T] = []
    while graph:
        ready = {node for node, deps in graph.items() if not deps}
        if not ready:
            raise ValueError("Cycle error")

        for node in ready:
            del graph[node]
        for deps in graph.values():
            deps -= ready

        result.extend(sorted(ready))  # type: ignore

    return result

Sorts a given directed acyclic graph in topological order.

Arguments:
  • dag: The graph to be sorted.
Returns:

A list that contains all of the graph's nodes in topological order.

def open_file(file_name: str) -> <class 'TextIO'>:
def open_file(file_name: str) -> t.TextIO:
    """
    Open a text file that may be gzip-compressed.

    The first two bytes are checked for the gzip magic number (0x1f 0x8b). In both cases
    the content is decoded as UTF-8 and opened with `newline=""`, so line endings are
    returned untranslated, as the `csv` module expects.

    Args:
        file_name: Path of the file to open.

    Returns:
        A text-mode file object; the caller is responsible for closing it.
    """
    with open(file_name, "rb") as f:
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        # Decode compressed files as UTF-8 too, rather than with the locale's default encoding.
        return gzip.open(file_name, "rt", encoding="utf-8", newline="")

    return open(file_name, encoding="utf-8", newline="")

Open a file that may be compressed as gzip and return it in universal newline mode.

@contextmanager
def csv_reader(read_csv: sqlglot.expressions.ReadCSV) -> Any:
@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    The options after the file name are read as flat key/value pairs. Only `delimiter` is
    recognized (default ","); other keys are ignored. The underlying file is closed when the
    context exits, including when an exception is raised.

    Args:
        read_csv: A `ReadCSV` function call.

    Yields:
        A python csv reader.
    """
    import csv as csv_  # aliased because this module defines its own `csv` function

    delimiter = ","
    # Zipping a single iterator with itself consumes the option names two at a time: (key, value).
    names = iter(arg.name for arg in read_csv.expressions)
    for key, value in zip(names, names):
        if key == "delimiter":
            delimiter = value

    # Open only after the options are parsed, and under `with`, so the file cannot leak.
    with open_file(read_csv.name) as file:
        yield csv_.reader(file, delimiter=delimiter)

Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]).

Arguments:
  • read_csv: A ReadCSV function call.
Yields:

A python csv reader.

def find_new_name(taken: Collection[str], base: str) -> str:
def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Return `base` if it is free, otherwise the first free `"{base}_{i}"` for i = 2, 3, ...

    Args:
        taken: A collection of taken names.
        base: Base name to alter.

    Returns:
        The new, available name.
    """
    if base not in taken:
        return base

    candidates = (f"{base}_{i}" for i in count(2))
    return next(name for name in candidates if name not in taken)

Searches for a new name.

Arguments:
  • taken: A collection of taken names.
  • base: Base name to alter.
Returns:

The new, available name.

def name_sequence(prefix: str) -> Callable[[], str]:
def name_sequence(prefix: str) -> t.Callable[[], str]:
    """
    Build a callable that returns `prefix` followed by 0, 1, 2, ... on successive calls.

    E.g. with prefix "a" it returns "a0", then "a1", then "a2", ... Each callable returned
    by this function has its own independent counter.
    """
    counter = count()

    def _next_name() -> str:
        return f"{prefix}{next(counter)}"

    return _next_name

Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").

def object_to_dict(obj: Any, **kwargs) -> Dict:
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """
    Build a dict from `obj`'s attributes (`vars(obj)`), copying each value.

    Values that have a `copy` attribute are copied with `value.copy()`; all others with
    `copy.copy`. Keyword arguments are applied last and override copied attributes.
    """
    result = {}
    for key, value in vars(obj).items():
        result[key] = value.copy() if hasattr(value, "copy") else copy(value)
    result.update(kwargs)
    return result

Returns a dictionary created from an object's attributes.

def split_num_words( value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> List[Optional[str]]:
def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Split `value` on `sep`, padding the result with `None` up to `min_num_words` entries.

    Args:
        value: The value to be split.
        sep: The separator to split on.
        min_num_words: The minimum number of entries in the result; if the split already
            produces at least this many words, no padding is added.
        fill_from_start: Whether the `None` padding goes at the start (True) or at the end
            (False) of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The list of words returned by `split`, possibly augmented by a number of `None` values.
    """
    words = value.split(sep)
    padding = [None] * max(0, min_num_words - len(words))
    return padding + words if fill_from_start else words + padding

Perform a split on a value and return N words as a result with None used for words that don't exist.

Arguments:
  • value: The value to be split.
  • sep: The value to use to split on.
  • min_num_words: The minimum number of words that are going to be in the result.
  • fill_from_start: Indicates whether None values should be inserted at the start (True) or at the end (False) of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
[None, 'db', 'table']
>>> split_num_words("db.table", ".", 3, fill_from_start=False)
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
Returns:

The list of words returned by split, possibly augmented by a number of None values.

def is_iterable(value: Any) -> bool:
def is_iterable(value: t.Any) -> bool:
    """
    Tell whether `value` is iterable, treating `str` and `bytes` as non-iterable.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check.

    Returns:
        True if `value` defines `__iter__` and is neither `str` nor `bytes`, else False.
    """
    if isinstance(value, (str, bytes)):
        return False
    return hasattr(value, "__iter__")

Checks if the value is an iterable, excluding the types str and bytes.

Examples:
>>> is_iterable([1,2])
True
>>> is_iterable("test")
False
Arguments:
  • value: The value to check if it is an iterable.
Returns:

A bool value indicating if it is an iterable.

def flatten(values: Iterable[Union[Iterable[Any], Any]]) -> Iterator[Any]:
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Lazily yield the leaf elements of an arbitrarily nested iterable.

    Anything `is_iterable` accepts is recursed into; everything else, including `str` and
    `bytes` objects, is yielded as-is.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The value to be flattened.

    Yields:
        Non-iterable elements in `values`, in iteration order.
    """
    for item in values:
        if not is_iterable(item):
            yield item
            continue
        yield from flatten(item)

Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type str and bytes are not regarded as iterables.

Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
[1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Arguments:
  • values: The value to be flattened.
Yields:

Non-iterable elements in values.

def dict_depth(d: Dict) -> int:
def dict_depth(d: t.Dict) -> int:
    """
    Get the nesting depth of a dictionary.

    Only the first value at each level is followed, so a dict whose first value is not a
    dict counts as depth 1 even if later values are nested. Objects without a `values`
    method (e.g. `None`) have depth 0.

    Example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3
    """
    try:
        first_value = next(iter(d.values()))
    except AttributeError:
        # `d` has no `values` method, so it is not dict-like.
        return 0
    except StopIteration:
        # An empty dict still counts as one level.
        return 1
    return 1 + dict_depth(first_value)

Get the nesting depth of a dictionary.

Example:
>>> dict_depth(None)
0
>>> dict_depth({})
1
>>> dict_depth({"a": "b"})
1
>>> dict_depth({"a": {}})
2
>>> dict_depth({"a": {"b": {}}})
3
def first(it: Iterable[~T]) -> ~T:
def first(it: t.Iterable[T]) -> T:
    """Return the first item produced by iterating `it` (handy for sets); raises StopIteration if empty."""
    return next(iter(it))

Returns the first element from an iterable (useful for sets).