Edit on GitHub

sqlglot.helper

  1from __future__ import annotations
  2
  3import inspect
  4import logging
  5import re
  6import sys
  7import typing as t
  8from collections.abc import Collection
  9from contextlib import contextmanager
 10from copy import copy
 11from enum import Enum
 12from itertools import count
 13
 14if t.TYPE_CHECKING:
 15    from sqlglot import exp
 16    from sqlglot._typing import E, T
 17    from sqlglot.expressions import Expression
 18
# Zero-width pattern: matches the position right before every uppercase ASCII letter,
# except when that position is the very start of the string.
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
# (major, minor) version of the running interpreter.
PYTHON_VERSION = sys.version_info[:2]
# Logger registered under the name "sqlglot".
logger = logging.getLogger("sqlglot")
 22
 23
class AutoName(Enum):
    """
    This is used for creating enum classes where `auto()` is the string form of the
    corresponding value's name.
    """

    # Hook consulted by `enum.auto()` to produce a member's value; the start, count and
    # previous-values arguments are deliberately ignored and the member's name is returned.
    def _generate_next_value_(name, _start, _count, _last_values):
        return name
 29
 30
 31def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
 32    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
 33    try:
 34        return seq[index]
 35    except IndexError:
 36        return None
 37
 38
 39@t.overload
 40def ensure_list(value: t.Collection[T]) -> t.List[T]:
 41    ...
 42
 43
 44@t.overload
 45def ensure_list(value: T) -> t.List[T]:
 46    ...
 47
 48
 49def ensure_list(value):
 50    """
 51    Ensures that a value is a list, otherwise casts or wraps it into one.
 52
 53    Args:
 54        value: the value of interest.
 55
 56    Returns:
 57        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
 58    """
 59    if value is None:
 60        return []
 61    if isinstance(value, (list, tuple)):
 62        return list(value)
 63
 64    return [value]
 65
 66
 67@t.overload
 68def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
 69    ...
 70
 71
 72@t.overload
 73def ensure_collection(value: T) -> t.Collection[T]:
 74    ...
 75
 76
 77def ensure_collection(value):
 78    """
 79    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 80
 81    Args:
 82        value: the value of interest.
 83
 84    Returns:
 85        The value if it's a collection, or else the value wrapped in a list.
 86    """
 87    if value is None:
 88        return []
 89    return (
 90        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
 91    )
 92
 93
 94def csv(*args: str, sep: str = ", ") -> str:
 95    """
 96    Formats any number of string arguments as CSV.
 97
 98    Args:
 99        args: the string arguments to format.
100        sep: the argument separator.
101
102    Returns:
103        The arguments formatted as a CSV string.
104    """
105    return sep.join(arg for arg in args if arg)
106
107
def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Note that `issubclass` treats a class as a subclass of itself, so the target classes are
    part of the result when they are members of the module and not excluded.

    Args:
        module_name: the name of the module to search for subclasses in. The module must
            already be imported, since it is looked up in `sys.modules`.
        classes: class(es) we want to find the subclasses of.
        exclude: class(es) we want to exclude from the returned list.

    Returns:
        The target subclasses, in the name-sorted order produced by `inspect.getmembers`.
    """
    # A single class is a valid `exclude` per the signature, but `obj not in <class>` raises
    # a TypeError for ordinary classes, so it is wrapped in a tuple first.
    if isinstance(exclude, type):
        exclude = (exclude,)

    def is_target(obj: t.Any) -> bool:
        return inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude

    return [obj for _, obj in inspect.getmembers(sys.modules[module_name], is_target)]
131
132
def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[t.Optional[E]],
    offset: int,
) -> t.List[t.Optional[E]]:
    """
    Shifts a single integer-typed index expression by `offset`.

    Types are computed through `annotate_types` for `this` and for the index whenever they
    are missing.

    Args:
        this: the expression being indexed.
        expressions: the index expression, wrapped in a list.
        offset: the amount to add to the index.

    Returns:
        A one-element list holding the simplified `index + offset` expression when the offset
        is non-zero, exactly one index is given, `this` is typed as unknown or array, and the
        index is integer-typed. In every other case `expressions` is returned as is.
    """
    if not offset or len(expressions) != 1:
        return expressions

    # Imported lazily, inside the function, rather than at module import time.
    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    if not this.type:
        annotate_types(this)

    indexable_types = (exp.DataType.Type.UNKNOWN, exp.DataType.Type.ARRAY)
    if t.cast(exp.DataType, this.type).this not in indexable_types:
        return expressions

    index = expressions[0]
    if not index:
        return expressions

    if not index.type:
        annotate_types(index)

    if t.cast(exp.DataType, index.type).this not in exp.DataType.INTEGER_TYPES:
        return expressions

    logger.warning("Applying array index offset (%s)", offset)
    shifted = exp.Add(this=index.copy(), expression=exp.Literal.number(offset))
    return [simplify(shifted)]
179
180
def camel_to_snake_case(name: str) -> str:
    """
    Converts `name` from camelCase to upper-case snake case (e.g. `camelCase` -> `CAMEL_CASE`).

    An underscore is inserted before every uppercase ASCII letter that is not the first
    character, and the whole result is then upper-cased.

    Args:
        name: the camelCase name to convert.

    Returns:
        The upper-cased, underscore-separated form of `name`.
    """
    with_underscores = CAMEL_CASE_PATTERN.sub("_", name)
    return with_underscores.upper()
184
185
def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Applies a transformation to a given expression until a fix point is reached.

    The fix point is detected by comparing the hash of the expression taken before `func`
    is applied with the hash of the expression it returns.

    Args:
        expression: the expression to be transformed.
        func: the transformation to be applied.

    Returns:
        The transformed expression.
    """
    while True:
        # Store each node's hash on the node itself, visiting the walk in reverse order.
        # NOTE(review): presumably this lets parents reuse their children's cached hashes
        # -- confirm against `Expression.__hash__` and the order produced by `walk()`.
        for n, *_ in reversed(tuple(expression.walk())):
            n._hash = hash(n)
        start = hash(expression)
        expression = func(expression)

        # Clear the stored hashes on the result before hashing it for the comparison below.
        for n, *_ in expression.walk():
            n._hash = None
        if start == hash(expression):
            break
    return expression
208
209
def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    The input mapping and its dependency sets are left untouched; the sort operates on a
    private copy.

    Args:
        dag: the graph to be sorted, mapping each node to the set of nodes it depends on.
            Nodes that only appear as dependencies are treated as having no dependencies.

    Returns:
        A list that contains all of the graph's nodes in topological order, i.e. every node
        comes after all of its dependencies. Nodes that become available in the same round
        are emitted in sorted order.

    Raises:
        ValueError: if the graph contains a cycle.
    """
    # Copy both the mapping and each dependency set: the loop below consumes them.
    pending = {node: set(deps) for node, deps in dag.items()}

    # Register nodes that are only ever mentioned as a dependency.
    for deps in tuple(pending.values()):
        for dep in deps:
            pending.setdefault(dep, set())

    result = []

    while pending:
        ready = {node for node, deps in pending.items() if not deps}

        if not ready:
            raise ValueError("Cycle error")

        for node in ready:
            del pending[node]

        for deps in pending.values():
            deps -= ready

        result.extend(sorted(ready))  # type: ignore

    return result
242
243
def open_file(file_name: str) -> t.TextIO:
    """
    Open a file that may be compressed as gzip and return it as a UTF-8 text stream.

    The file is opened with `newline=""`, so line endings are handed to the caller
    untranslated. The caller is responsible for closing the returned file object.

    Args:
        file_name: the path of the file to open.

    Returns:
        A text-mode file object, transparently decompressing gzip content.
    """
    # Gzip streams start with the two magic bytes 0x1f 0x8b.
    with open(file_name, "rb") as f:
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        # Pin the encoding so that both branches decode identically, independent of locale.
        return gzip.open(file_name, "rt", encoding="utf-8", newline="")

    return open(file_name, encoding="utf-8", newline="")
255
256
@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    The underlying file is closed when the context exits.

    Args:
        read_csv: a `ReadCSV` function call

    Yields:
        A python csv reader.
    """
    import csv as csv_

    # Options arrive as a flat list of alternating keys and values. They are parsed before
    # the file is opened, so a failure here cannot leave an open handle behind.
    option_names = [arg.name for arg in read_csv.expressions]

    delimiter = ","
    for key, value in zip(option_names[::2], option_names[1::2]):
        if key == "delimiter":
            delimiter = value

    with open_file(read_csv.name) as file:
        yield csv_.reader(file, delimiter=delimiter)
283
284
def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Picks a name that is not already in use.

    Args:
        taken: a collection of taken names.
        base: the preferred name.

    Returns:
        `base` if it is available, otherwise the first available name among
        `base_2`, `base_3`, and so on.
    """
    if base not in taken:
        return base

    candidates = (f"{base}_{suffix}" for suffix in count(2))
    return next(name for name in candidates if name not in taken)
306
307
def name_sequence(prefix: str) -> t.Callable[[], str]:
    """
    Builds a generator of numbered names sharing a prefix.

    Args:
        prefix: the text placed in front of each number.

    Returns:
        A function that returns `prefix0`, `prefix1`, `prefix2`, ... on successive calls.
    """
    counter = count()

    def next_name() -> str:
        return f"{prefix}{next(counter)}"

    return next_name
312
313
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """
    Builds a dictionary out of an object's attributes.

    Each attribute value is duplicated through its own `copy()` method when it has one,
    and through a shallow `copy.copy` otherwise.

    Args:
        obj: the object whose attributes (as given by `vars`) are collected.
        kwargs: extra entries that are added to, or override, the collected attributes.

    Returns:
        The resulting dictionary.
    """
    attributes = {}
    for key, value in vars(obj).items():
        attributes[key] = value.copy() if hasattr(value, "copy") else copy(value)

    attributes.update(kwargs)
    return attributes
320
321
def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Splits a value and pads the result with `None` until it holds at least `min_num_words` items.

    Args:
        value: the value to be split.
        sep: the separator to split on.
        min_num_words: the minimum number of items in the result.
        fill_from_start: whether the `None` padding goes at the start (`True`) or at the
            end (`False`) of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The words produced by `split`, padded with `None` values when there are too few.
    """
    words = value.split(sep)
    # A non-positive multiplier yields an empty list, so no padding is added in that case.
    padding = [None] * (min_num_words - len(words))
    return padding + words if fill_from_start else words + padding
349
350
def is_iterable(value: t.Any) -> bool:
    """
    Tells whether a value can be iterated over, treating `str` and `bytes` as scalars.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: the value to inspect.

    Returns:
        `True` if `value` defines `__iter__` and is neither a `str` nor a `bytes`.
    """
    if isinstance(value, (str, bytes)):
        return False

    return hasattr(value, "__iter__")
368
369
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Lazily walks arbitrarily nested iterables and produces their leaf elements in order.
    Objects of type `str` and `bytes` are treated as leaves, not as iterables.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: the iterable to be flattened.

    Yields:
        Non-iterable elements in `values`.
    """
    for item in values:
        if not is_iterable(item):
            yield item
            continue

        # Nested iterable: descend into it before moving on to the next sibling.
        yield from flatten(item)
392
393
def dict_depth(d: t.Dict) -> int:
    """
    Computes how deeply a dictionary is nested, following only the first value at each level.

    For example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3

    Args:
        d (dict): dictionary

    Returns:
        int: depth
    """
    nothing = object()

    try:
        head = next(iter(d.values()), nothing)
    except AttributeError:
        # d doesn't have attribute "values", so it is not a dictionary level
        return 0

    if head is nothing:
        # d.values() is empty: an empty dictionary still counts as one level
        return 1

    return 1 + dict_depth(head)
424
425
def first(it: t.Iterable[T]) -> T:
    """Fetches the leading element of an iterable.

    Handy for containers that cannot be indexed, such as sets. An empty iterable
    results in a `StopIteration`.
    """
    return next(iter(it))
class AutoName(enum.Enum):
25class AutoName(Enum):
26    """This is used for creating enum classes where `auto()` is the string form of the corresponding value's name."""
27
28    def _generate_next_value_(name, _start, _count, _last_values):
29        return name

This is used for creating enum classes where auto() is the string form of the corresponding value's name.

Inherited Members
enum.Enum
name
value
def seq_get(seq: Sequence[~T], index: int) -> Optional[~T]:
32def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
33    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
34    try:
35        return seq[index]
36    except IndexError:
37        return None

Returns the value in seq at position index, or None if index is out of bounds.

def ensure_list(value):
50def ensure_list(value):
51    """
52    Ensures that a value is a list, otherwise casts or wraps it into one.
53
54    Args:
55        value: the value of interest.
56
57    Returns:
58        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
59    """
60    if value is None:
61        return []
62    if isinstance(value, (list, tuple)):
63        return list(value)
64
65    return [value]

Ensures that a value is a list, otherwise casts or wraps it into one.

Arguments:
  • value: the value of interest.
Returns:

The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.

def ensure_collection(value):
78def ensure_collection(value):
79    """
80    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
81
82    Args:
83        value: the value of interest.
84
85    Returns:
86        The value if it's a collection, or else the value wrapped in a list.
87    """
88    if value is None:
89        return []
90    return (
91        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
92    )

Ensures that a value is a collection (excluding str and bytes), otherwise wraps it into a list.

Arguments:
  • value: the value of interest.
Returns:

The value if it's a collection, or else the value wrapped in a list.

def csv(*args: str, sep: str = ', ') -> str:
 95def csv(*args: str, sep: str = ", ") -> str:
 96    """
 97    Formats any number of string arguments as CSV.
 98
 99    Args:
100        args: the string arguments to format.
101        sep: the argument separator.
102
103    Returns:
104        The arguments formatted as a CSV string.
105    """
106    return sep.join(arg for arg in args if arg)

Formats any number of string arguments as CSV.

Arguments:
  • args: the string arguments to format.
  • sep: the argument separator.
Returns:

The arguments formatted as a CSV string.

def subclasses( module_name: str, classes: Union[Type, Tuple[Type, ...]], exclude: Union[Type, Tuple[Type, ...]] = ()) -> List[Type]:
109def subclasses(
110    module_name: str,
111    classes: t.Type | t.Tuple[t.Type, ...],
112    exclude: t.Type | t.Tuple[t.Type, ...] = (),
113) -> t.List[t.Type]:
114    """
115    Returns all subclasses for a collection of classes, possibly excluding some of them.
116
117    Args:
118        module_name: the name of the module to search for subclasses in.
119        classes: class(es) we want to find the subclasses of.
120        exclude: class(es) we want to exclude from the returned list.
121
122    Returns:
123        The target subclasses.
124    """
125    return [
126        obj
127        for _, obj in inspect.getmembers(
128            sys.modules[module_name],
129            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
130        )
131    ]

Returns all subclasses for a collection of classes, possibly excluding some of them.

Arguments:
  • module_name: the name of the module to search for subclasses in.
  • classes: class(es) we want to find the subclasses of.
  • exclude: class(es) we want to exclude from the returned list.
Returns:

The target subclasses.

def apply_index_offset( this: sqlglot.expressions.Expression, expressions: List[Optional[~E]], offset: int) -> List[Optional[~E]]:
134def apply_index_offset(
135    this: exp.Expression,
136    expressions: t.List[t.Optional[E]],
137    offset: int,
138) -> t.List[t.Optional[E]]:
139    """
140    Applies an offset to a given integer literal expression.
141
142    Args:
143        this: the target of the index
144        expressions: the expression the offset will be applied to, wrapped in a list.
145        offset: the offset that will be applied.
146
147    Returns:
148        The original expression with the offset applied to it, wrapped in a list. If the provided
149        `expressions` argument contains more than one expressions, it's returned unaffected.
150    """
151    if not offset or len(expressions) != 1:
152        return expressions
153
154    expression = expressions[0]
155
156    from sqlglot import exp
157    from sqlglot.optimizer.annotate_types import annotate_types
158    from sqlglot.optimizer.simplify import simplify
159
160    if not this.type:
161        annotate_types(this)
162
163    if t.cast(exp.DataType, this.type).this not in (
164        exp.DataType.Type.UNKNOWN,
165        exp.DataType.Type.ARRAY,
166    ):
167        return expressions
168
169    if expression:
170        if not expression.type:
171            annotate_types(expression)
172        if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
173            logger.warning("Applying array index offset (%s)", offset)
174            expression = simplify(
175                exp.Add(this=expression.copy(), expression=exp.Literal.number(offset))
176            )
177            return [expression]
178
179    return expressions

Applies an offset to a given integer literal expression.

Arguments:
  • this: the target of the index
  • expressions: the expression the offset will be applied to, wrapped in a list.
  • offset: the offset that will be applied.
Returns:

The original expression with the offset applied to it, wrapped in a list. If the offset is zero or the provided expressions argument does not contain exactly one expression, it's returned unaffected.

def camel_to_snake_case(name: str) -> str:
182def camel_to_snake_case(name: str) -> str:
183    """Converts `name` from camelCase to snake_case and returns the result."""
184    return CAMEL_CASE_PATTERN.sub("_", name).upper()

Converts name from camelCase to upper-case snake case (e.g. camelCase becomes CAMEL_CASE) and returns the result.

def while_changing( expression: sqlglot.expressions.Expression, func: Callable[[sqlglot.expressions.Expression], ~E]) -> ~E:
187def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
188    """
189    Applies a transformation to a given expression until a fix point is reached.
190
191    Args:
192        expression: the expression to be transformed.
193        func: the transformation to be applied.
194
195    Returns:
196        The transformed expression.
197    """
198    while True:
199        for n, *_ in reversed(tuple(expression.walk())):
200            n._hash = hash(n)
201        start = hash(expression)
202        expression = func(expression)
203
204        for n, *_ in expression.walk():
205            n._hash = None
206        if start == hash(expression):
207            break
208    return expression

Applies a transformation to a given expression until a fix point is reached.

Arguments:
  • expression: the expression to be transformed.
  • func: the transformation to be applied.
Returns:

The transformed expression.

def tsort(dag: Dict[~T, Set[~T]]) -> List[~T]:
211def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
212    """
213    Sorts a given directed acyclic graph in topological order.
214
215    Args:
216        dag: the graph to be sorted.
217
218    Returns:
219        A list that contains all of the graph's nodes in topological order.
220    """
221    result = []
222
223    for node, deps in tuple(dag.items()):
224        for dep in deps:
225            if not dep in dag:
226                dag[dep] = set()
227
228    while dag:
229        current = {node for node, deps in dag.items() if not deps}
230
231        if not current:
232            raise ValueError("Cycle error")
233
234        for node in current:
235            dag.pop(node)
236
237        for deps in dag.values():
238            deps -= current
239
240        result.extend(sorted(current))  # type: ignore
241
242    return result

Sorts a given directed acyclic graph in topological order.

Arguments:
  • dag: the graph to be sorted.
Returns:

A list that contains all of the graph's nodes in topological order.

def open_file(file_name: str) -> TextIO:
245def open_file(file_name: str) -> t.TextIO:
246    """Open a file that may be compressed as gzip and return it in universal newline mode."""
247    with open(file_name, "rb") as f:
248        gzipped = f.read(2) == b"\x1f\x8b"
249
250    if gzipped:
251        import gzip
252
253        return gzip.open(file_name, "rt", newline="")
254
255    return open(file_name, encoding="utf-8", newline="")

Open a file that may be compressed as gzip and return it in universal newline mode.

@contextmanager
def csv_reader(read_csv: sqlglot.expressions.ReadCSV) -> Any:
258@contextmanager
259def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
260    """
261    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
262
263    Args:
264        read_csv: a `ReadCSV` function call
265
266    Yields:
267        A python csv reader.
268    """
269    args = read_csv.expressions
270    file = open_file(read_csv.name)
271
272    delimiter = ","
273    args = iter(arg.name for arg in args)
274    for k, v in zip(args, args):
275        if k == "delimiter":
276            delimiter = v
277
278    try:
279        import csv as csv_
280
281        yield csv_.reader(file, delimiter=delimiter)
282    finally:
283        file.close()

Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]).

Arguments:
  • read_csv: a ReadCSV function call
Yields:

A python csv reader.

def find_new_name(taken: Collection[str], base: str) -> str:
286def find_new_name(taken: t.Collection[str], base: str) -> str:
287    """
288    Searches for a new name.
289
290    Args:
291        taken: a collection of taken names.
292        base: base name to alter.
293
294    Returns:
295        The new, available name.
296    """
297    if base not in taken:
298        return base
299
300    i = 2
301    new = f"{base}_{i}"
302    while new in taken:
303        i += 1
304        new = f"{base}_{i}"
305
306    return new

Searches for a new name.

Arguments:
  • taken: a collection of taken names.
  • base: base name to alter.
Returns:

The new, available name.

def name_sequence(prefix: str) -> Callable[[], str]:
309def name_sequence(prefix: str) -> t.Callable[[], str]:
310    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
311    sequence = count()
312    return lambda: f"{prefix}{next(sequence)}"

Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").

def object_to_dict(obj: Any, **kwargs) -> Dict:
315def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
316    """Returns a dictionary created from an object's attributes."""
317    return {
318        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
319        **kwargs,
320    }

Returns a dictionary created from an object's attributes.

def split_num_words( value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> List[Optional[str]]:
323def split_num_words(
324    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
325) -> t.List[t.Optional[str]]:
326    """
327    Perform a split on a value and return N words as a result with `None` used for words that don't exist.
328
329    Args:
330        value: the value to be split.
331        sep: the value to use to split on.
332        min_num_words: the minimum number of words that are going to be in the result.
333        fill_from_start: indicates that if `None` values should be inserted at the start or end of the list.
334
335    Examples:
336        >>> split_num_words("db.table", ".", 3)
337        [None, 'db', 'table']
338        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
339        ['db', 'table', None]
340        >>> split_num_words("db.table", ".", 1)
341        ['db', 'table']
342
343    Returns:
344        The list of words returned by `split`, possibly augmented by a number of `None` values.
345    """
346    words = value.split(sep)
347    if fill_from_start:
348        return [None] * (min_num_words - len(words)) + words
349    return words + [None] * (min_num_words - len(words))

Perform a split on a value and return N words as a result with None used for words that don't exist.

Arguments:
  • value: the value to be split.
  • sep: the value to use to split on.
  • min_num_words: the minimum number of words that are going to be in the result.
  • fill_from_start: indicates whether None values should be inserted at the start or at the end of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
[None, 'db', 'table']
>>> split_num_words("db.table", ".", 3, fill_from_start=False)
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
Returns:

The list of words returned by split, possibly augmented by a number of None values.

def is_iterable(value: Any) -> bool:
352def is_iterable(value: t.Any) -> bool:
353    """
354    Checks if the value is an iterable, excluding the types `str` and `bytes`.
355
356    Examples:
357        >>> is_iterable([1,2])
358        True
359        >>> is_iterable("test")
360        False
361
362    Args:
363        value: the value to check if it is an iterable.
364
365    Returns:
366        A `bool` value indicating if it is an iterable.
367    """
368    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes))

Checks if the value is an iterable, excluding the types str and bytes.

Examples:
>>> is_iterable([1,2])
True
>>> is_iterable("test")
False
Arguments:
  • value: the value to check if it is an iterable.
Returns:

A bool value indicating if it is an iterable.

def flatten(values: Iterable[Union[Iterable[Any], Any]]) -> Iterator[Any]:
371def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
372    """
373    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
374    type `str` and `bytes` are not regarded as iterables.
375
376    Examples:
377        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
378        [1, 2, 3, 4, 5, 'bla']
379        >>> list(flatten([1, 2, 3]))
380        [1, 2, 3]
381
382    Args:
383        values: the value to be flattened.
384
385    Yields:
386        Non-iterable elements in `values`.
387    """
388    for value in values:
389        if is_iterable(value):
390            yield from flatten(value)
391        else:
392            yield value

Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type str and bytes are not regarded as iterables.

Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
[1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Arguments:
  • values: the value to be flattened.
Yields:

Non-iterable elements in values.

def dict_depth(d: Dict) -> int:
395def dict_depth(d: t.Dict) -> int:
396    """
397    Get the nesting depth of a dictionary.
398
399    For example:
400        >>> dict_depth(None)
401        0
402        >>> dict_depth({})
403        1
404        >>> dict_depth({"a": "b"})
405        1
406        >>> dict_depth({"a": {}})
407        2
408        >>> dict_depth({"a": {"b": {}}})
409        3
410
411    Args:
412        d (dict): dictionary
413
414    Returns:
415        int: depth
416    """
417    try:
418        return 1 + dict_depth(next(iter(d.values())))
419    except AttributeError:
420        # d doesn't have attribute "values"
421        return 0
422    except StopIteration:
423        # d.values() returns an empty sequence
424        return 1

Get the nesting depth of a dictionary.

For example:
>>> dict_depth(None)
0
>>> dict_depth({})
1
>>> dict_depth({"a": "b"})
1
>>> dict_depth({"a": {}})
2
>>> dict_depth({"a": {"b": {}}})
3
Arguments:
  • d (dict): dictionary
Returns:

int: depth

def first(it: Iterable[~T]) -> ~T:
427def first(it: t.Iterable[T]) -> T:
428    """Returns the first element from an iterable.
429
430    Useful for sets.
431    """
432    return next(i for i in it)

Returns the first element from an iterable.

Useful for sets.