Edit on GitHub

sqlglot.helper

  1from __future__ import annotations
  2
  3import datetime
  4import inspect
  5import logging
  6import re
  7import sys
  8import typing as t
  9from collections.abc import Collection
 10from contextlib import contextmanager
 11from copy import copy
 12from enum import Enum
 13from itertools import count
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot import exp
 17    from sqlglot._typing import A, E, T
 18    from sqlglot.expressions import Expression
 19
 20
 21CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
 22PYTHON_VERSION = sys.version_info[:2]
 23logger = logging.getLogger("sqlglot")
 24
 25
 26class AutoName(Enum):
 27    """
 28    This is used for creating Enum classes where `auto()` is the string form
 29    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
 30
 31    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
 32    """
 33
 34    def _generate_next_value_(name, _start, _count, _last_values):
 35        return name
 36
 37
 38class classproperty(property):
 39    """
 40    Similar to a normal property but works for class methods
 41    """
 42
 43    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
 44        return classmethod(self.fget).__get__(None, owner)()  # type: ignore
 45
 46
 47def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
 48    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
 49    try:
 50        return seq[index]
 51    except IndexError:
 52        return None
 53
 54
 55@t.overload
 56def ensure_list(value: t.Collection[T]) -> t.List[T]:
 57    ...
 58
 59
 60@t.overload
 61def ensure_list(value: T) -> t.List[T]:
 62    ...
 63
 64
 65def ensure_list(value):
 66    """
 67    Ensures that a value is a list, otherwise casts or wraps it into one.
 68
 69    Args:
 70        value: The value of interest.
 71
 72    Returns:
 73        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
 74    """
 75    if value is None:
 76        return []
 77    if isinstance(value, (list, tuple)):
 78        return list(value)
 79
 80    return [value]
 81
 82
 83@t.overload
 84def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
 85    ...
 86
 87
 88@t.overload
 89def ensure_collection(value: T) -> t.Collection[T]:
 90    ...
 91
 92
 93def ensure_collection(value):
 94    """
 95    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 96
 97    Args:
 98        value: The value of interest.
 99
100    Returns:
101        The value if it's a collection, or else the value wrapped in a list.
102    """
103    if value is None:
104        return []
105    return (
106        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
107    )
108
109
def csv(*args: str, sep: str = ", ") -> str:
    """
    Join the non-empty string arguments with a separator.

    Falsy arguments (e.g. `""`) are dropped so they don't produce doubled separators.

    Args:
        args: The strings to join.
        sep: Placed between consecutive kept arguments.

    Returns:
        The joined string.
    """
    return sep.join(filter(None, args))
122
123
def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Args:
        module_name: The name of an already-imported module (looked up in `sys.modules`)
            to search for subclasses in.
        classes: Class(es) we want to find the subclasses of. A class counts as its own
            subclass, so `classes` itself is returned if it lives in the module and is not
            excluded.
        exclude: A class, or tuple of classes, to leave out of the returned list.

    Returns:
        The target subclasses, in the order produced by `inspect.getmembers` (by name).
    """
    # `exclude` may be a single class; normalize it to a tuple so the membership test
    # below works (`obj in SomeClass` would raise TypeError).
    excluded = exclude if isinstance(exclude, tuple) else (exclude,)

    return [
        obj
        for _, obj in inspect.getmembers(
            sys.modules[module_name],
            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in excluded,
        )
    ]
147
148
def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[E],
    offset: int,
) -> t.List[E]:
    """
    Adds `offset` to an integer-typed array index expression.

    Args:
        this: The target of the index (e.g. the array in `arr[i]`).
        expressions: The index expression the offset will be applied to, wrapped in a list.
        offset: The offset that will be applied.

    Returns:
        A single-element list holding the simplified `index + offset` expression. The
        `expressions` argument is returned unaffected when `offset` is 0, when it does
        not contain exactly one expression, when `this` has a type other than ARRAY or
        UNKNOWN, or when the index expression is not of an integer type.
    """
    if not offset or len(expressions) != 1:
        return expressions

    expression = expressions[0]

    # Imported lazily: these modules live in the sqlglot package itself.
    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    # Infer types only when missing; annotate_types is relied on to populate `.type`
    # on the node in place (its return value is not used).
    if not this.type:
        annotate_types(this)

    # Only arrays (or targets whose type could not be inferred) are indexed positionally.
    if t.cast(exp.DataType, this.type).this not in (
        exp.DataType.Type.UNKNOWN,
        exp.DataType.Type.ARRAY,
    ):
        return expressions

    if not expression.type:
        annotate_types(expression)
    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
        logger.warning("Applying array index offset (%s)", offset)
        expression = simplify(exp.Add(this=expression, expression=exp.Literal.number(offset)))
        return [expression]

    return expressions
192
193
def camel_to_snake_case(name: str) -> str:
    """
    Converts `name` from camelCase / CamelCase to UPPER_SNAKE_CASE and returns the result.

    An underscore is inserted before every uppercase letter except the first character,
    then the whole string is upper-cased, e.g. `"DateAdd"` -> `"DATE_ADD"`. Runs of
    capitals are split letter by letter (`"HTTPServer"` -> `"H_T_T_P_SERVER"`).
    """
    return CAMEL_CASE_PATTERN.sub("_", name).upper()
197
198
def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Repeatedly applies `func` to `expression` until it stops changing.

    "Stops changing" means the hash of the tree returned by `func` equals the hash of
    the tree it was given.

    Args:
        expression: The root of the tree to be transformed.
        func: The transformation to be applied; it returns the (possibly new) root.

    Returns:
        The root produced by the final application of `func`.
    """
    while True:
        # Cache every node's hash in `_hash`, children before parents (hence `reversed`).
        # NOTE(review): assumes Expression.__hash__ reuses a cached `_hash` - confirm.
        for node, *_ in reversed(tuple(expression.walk())):
            node._hash = hash(node)

        before = hash(expression)
        expression = func(expression)

        # Drop the cached hashes so the comparison below sees the tree as it is now.
        for node, *_ in expression.walk():
            node._hash = None

        if hash(expression) == before:
            return expression
223
224
def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    Nodes are emitted level by level: a node appears only after all of its dependencies,
    and nodes that become ready together are ordered with `sorted`, so the output is
    deterministic. The input graph is left unmodified.

    Args:
        dag: Maps each node to the set of nodes it depends on. Dependencies that are not
            keys themselves are treated as nodes without dependencies.

    Returns:
        A list that contains all of the graph's nodes in topological order.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    # Work on a private copy: the loop below consumes the graph, and mutating the
    # caller's dict (and its dependency sets) would be a surprising side effect.
    graph = {node: set(deps) for node, deps in dag.items()}
    for deps in dag.values():
        for dep in deps:
            graph.setdefault(dep, set())

    result: t.List[T] = []
    while graph:
        ready = {node for node, deps in graph.items() if not deps}
        if not ready:
            raise ValueError("Cycle error")

        for node in ready:
            del graph[node]
        for deps in graph.values():
            deps -= ready

        result.extend(sorted(ready))  # type: ignore

    return result
257
258
def open_file(file_name: str) -> t.TextIO:
    """
    Open a text file that may be gzip-compressed.

    Compression is detected from the gzip magic number (first two bytes `1f 8b`), not
    from the file extension. Both branches decode as UTF-8 and use `newline=""`
    (universal newlines, with line endings returned untranslated, as the `csv` module
    expects).

    Args:
        file_name: Path of the file to open.

    Returns:
        A text-mode file object; the caller is responsible for closing it.
    """
    with open(file_name, "rb") as f:
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        # Pass the encoding explicitly: without it gzip's text mode falls back to the
        # locale's preferred encoding, unlike the plain-file branch below.
        return gzip.open(file_name, "rt", encoding="utf-8", newline="")

    return open(file_name, encoding="utf-8", newline="")
270
271
@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    Args:
        read_csv: A `ReadCSV` function call. Its `expressions` are read as alternating
            option names and values; only `delimiter` is recognized (default `","`, the
            last occurrence wins), and a trailing name without a value is ignored.

    Yields:
        A python csv reader over the file opened by `open_file`; the file is closed when
        the context exits.
    """
    import csv as csv_  # aliased because this module defines its own `csv` function

    # Parse the options before opening the file so a failure here cannot leak the handle.
    names = [arg.name for arg in read_csv.expressions]
    options = dict(zip(names[::2], names[1::2]))
    delimiter = options.get("delimiter", ",")

    with open_file(read_csv.name) as file:
        yield csv_.reader(file, delimiter=delimiter)
298
299
def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Pick a name derived from `base` that is not already taken.

    Args:
        taken: Names that are already in use.
        base: The preferred name.

    Returns:
        `base` if it is free, otherwise the first free name among `base_2`, `base_3`, ...
    """
    if base not in taken:
        return base

    suffix = 2
    while f"{base}_{suffix}" in taken:
        suffix += 1
    return f"{base}_{suffix}"
321
322
def is_int(text: str) -> bool:
    """
    Return whether `int(text)` succeeds.

    Surrounding whitespace, a leading sign and `_` digit separators are accepted, as
    they are by `int`.
    """
    try:
        int(text)
    except ValueError:
        return False
    return True
329
330
def name_sequence(prefix: str) -> t.Callable[[], str]:
    """
    Build a generator of unique names.

    Each call of the returned function produces `prefix` followed by the next integer,
    starting from 0 (e.g. a0, a1, a2, ... if the prefix is "a").
    """
    counter = count()

    def next_name() -> str:
        return f"{prefix}{next(counter)}"

    return next_name
335
336
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """
    Snapshot an object's instance attributes (`vars(obj)`) into a new dict.

    Each value is shallow-copied, using its own `copy()` method when it has one and
    `copy.copy` otherwise. Keyword arguments are merged in last and override attributes
    of the same name.
    """
    snapshot = {}
    for attr, value in vars(obj).items():
        snapshot[attr] = value.copy() if hasattr(value, "copy") else copy(value)
    snapshot.update(kwargs)
    return snapshot
343
344
def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Split `value` on `sep`, padding the result with `None` up to `min_num_words` items.

    Args:
        value: The string to split.
        sep: The separator passed to `str.split`.
        min_num_words: The minimum length of the result; words are never dropped.
        fill_from_start: Whether the `None` padding goes at the start (True) or at the
            end (False) of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The words produced by `split`, plus any `None` padding.
    """
    words = value.split(sep)
    # A non-positive count yields an empty list, so nothing is added when there are
    # already enough words.
    padding = [None] * (min_num_words - len(words))
    return padding + words if fill_from_start else words + padding
372
373
def is_iterable(value: t.Any) -> bool:
    """
    Checks if the value is an iterable, excluding the types `str`, `bytes` and sqlglot's
    `Expression`.

    A value counts as iterable when it defines `__iter__`. Strings, bytes and
    `Expression` nodes define it too, but are treated as atomic values here.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check if it is an iterable.

    Returns:
        A `bool` value indicating if it is an iterable.
    """
    # Imported at call time rather than at module level.
    # NOTE(review): presumably to avoid an import cycle with sqlglot - confirm.
    from sqlglot import Expression

    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))
393
394
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Recursively flatten nested iterables into a single stream of leaf values.

    Whether an element is descended into is decided by `is_iterable`, so `str` and
    `bytes` objects are yielded whole.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The (possibly nested) iterable to flatten.

    Yields:
        The non-iterable elements of `values`, depth-first in iteration order.
    """
    for item in values:
        if not is_iterable(item):
            yield item
            continue
        yield from flatten(item)
417
418
def dict_depth(d: t.Dict) -> int:
    """
    Get the nesting depth of a dictionary.

    Only the first value at each level is inspected; anything without a `values()`
    method has depth 0, and an empty mapping has depth 1.

    Example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3
    """
    try:
        first_value = next(iter(d.values()))
    except AttributeError:
        # Not a mapping: contributes no depth.
        return 0
    except StopIteration:
        # An empty mapping is one level deep.
        return 1
    return 1 + dict_depth(first_value)
443
444
def first(it: t.Iterable[T]) -> T:
    """
    Return the first item produced by iterating `it` (useful for sets, which can't be indexed).

    Raises:
        StopIteration: If `it` yields nothing.
    """
    return next(iter(it))
448
449
def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
    """
    Merge overlapping or touching `(low, high)` ranges over a totally-ordered type.

    The ranges are sorted first; a range is folded into its predecessor when it starts
    at or before the predecessor's end.

    Example:
        >>> merge_ranges([(1, 3), (2, 6)])
        [(1, 6)]
    """
    if not ranges:
        return []

    ordered = sorted(ranges)
    merged = ordered[:1]

    for low, high in ordered[1:]:
        prev_low, prev_high = merged[-1]
        if low <= prev_high:
            merged[-1] = (prev_low, max(prev_high, high))
        else:
            merged.append((low, high))

    return merged
475
476
def is_iso_date(text: str) -> bool:
    """Return whether `text` is accepted by `datetime.date.fromisoformat`."""
    try:
        datetime.date.fromisoformat(text)
    except ValueError:
        return False
    return True
483
484
def is_iso_datetime(text: str) -> bool:
    """
    Return whether `text` is accepted by `datetime.datetime.fromisoformat`.

    The accepted formats are those of the running Python version's `fromisoformat`.
    """
    try:
        datetime.datetime.fromisoformat(text)
    except ValueError:
        return False
    return True
491
492
# Interval unit names (lower-case) that operate on date components.
DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}


def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
    """
    Return whether `expression`'s name is one of `DATE_UNITS`, compared case-insensitively.

    `None` is accepted and yields False.
    """
    if expression is None:
        return False
    return expression.name.lower() in DATE_UNITS
CAMEL_CASE_PATTERN = re.compile('(?<!^)(?=[A-Z])')
PYTHON_VERSION = (3, 10)
logger = <Logger sqlglot (WARNING)>
class AutoName(enum.Enum):
class AutoName(Enum):
    """
    Enum base class whose `auto()` members use their own identifier as their value.

    For example, a member declared as `FOO = auto()` ends up with `FOO.value == "FOO"`.

    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
    """

    def _generate_next_value_(name, _start_value, _member_count, _previous_values):
        # Enum calls this hook without an instance, passing the member name first; the
        # remaining arguments are irrelevant because the value is simply the name.
        return name

This is used for creating Enum classes where auto() is the string form of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values

Inherited Members
enum.Enum
name
value
class classproperty(builtins.property):
class classproperty(property):
    """
    A property whose getter is evaluated against the owning class instead of an instance.

    The decorated getter receives the class, so the value is available both as
    `Cls.attr` and as `instance.attr`.
    """

    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
        # Bind the getter to `owner` exactly as a classmethod would, then call it; the
        # instance `obj` is intentionally ignored.
        bound_getter = classmethod(self.fget).__get__(None, owner)  # type: ignore
        return bound_getter()

Similar to a normal property but works for class methods

Inherited Members
builtins.property
property
getter
setter
deleter
fget
fset
fdel
def seq_get(seq: Sequence[~T], index: int) -> Optional[~T]:
def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    """
    Safely index into a sequence.

    Args:
        seq: The sequence to read from.
        index: The position to read; negative values count from the end as usual.

    Returns:
        `seq[index]`, or `None` when `index` is out of range.
    """
    try:
        item = seq[index]
    except IndexError:
        return None
    return item

Returns the value in seq at position index, or None if index is out of bounds.

def ensure_list(value):
def ensure_list(value):
    """
    Coerce `value` into a list.

    Args:
        value: Anything; `None`, lists and tuples are handled specially.

    Returns:
        An empty list for `None`, a new list with the same items for a list or tuple,
        and a single-element list wrapping `value` otherwise.
    """
    if value is None:
        return []
    return list(value) if isinstance(value, (list, tuple)) else [value]

Ensures that a value is a list, otherwise casts or wraps it into one.

Arguments:
  • value: The value of interest.
Returns:

The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.

def ensure_collection(value):
 94def ensure_collection(value):
 95    """
 96    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 97
 98    Args:
 99        value: The value of interest.
100
101    Returns:
102        The value if it's a collection, or else the value wrapped in a list.
103    """
104    if value is None:
105        return []
106    return (
107        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
108    )

Ensures that a value is a collection (excluding str and bytes), otherwise wraps it into a list.

Arguments:
  • value: The value of interest.
Returns:

The value if it's a collection, or else the value wrapped in a list.

def csv(*args: str, sep: str = ', ') -> str:
def csv(*args: str, sep: str = ", ") -> str:
    """
    Join the non-empty string arguments with a separator.

    Falsy arguments (e.g. `""`) are dropped so they don't produce doubled separators.

    Args:
        args: The strings to join.
        sep: Placed between consecutive kept arguments.

    Returns:
        The joined string.
    """
    return sep.join(filter(None, args))

Formats any number of string arguments as CSV.

Arguments:
  • args: The string arguments to format.
  • sep: The argument separator.
Returns:

The arguments formatted as a CSV string.

def subclasses( module_name: str, classes: Union[Type, Tuple[Type, ...]], exclude: Union[Type, Tuple[Type, ...]] = ()) -> List[Type]:
def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Args:
        module_name: The name of an already-imported module (looked up in `sys.modules`)
            to search for subclasses in.
        classes: Class(es) we want to find the subclasses of. A class counts as its own
            subclass, so `classes` itself is returned if it lives in the module and is not
            excluded.
        exclude: A class, or tuple of classes, to leave out of the returned list.

    Returns:
        The target subclasses, in the order produced by `inspect.getmembers` (by name).
    """
    # `exclude` may be a single class; normalize it to a tuple so the membership test
    # below works (`obj in SomeClass` would raise TypeError).
    excluded = exclude if isinstance(exclude, tuple) else (exclude,)

    return [
        obj
        for _, obj in inspect.getmembers(
            sys.modules[module_name],
            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in excluded,
        )
    ]

Returns all subclasses for a collection of classes, possibly excluding some of them.

Arguments:
  • module_name: The name of the module to search for subclasses in.
  • classes: Class(es) we want to find the subclasses of.
  • exclude: Class(es) we want to exclude from the returned list.
Returns:

The target subclasses.

def apply_index_offset( this: sqlglot.expressions.Expression, expressions: List[~E], offset: int) -> List[~E]:
def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[E],
    offset: int,
) -> t.List[E]:
    """
    Adds `offset` to an integer-typed array index expression.

    Args:
        this: The target of the index (e.g. the array in `arr[i]`).
        expressions: The index expression the offset will be applied to, wrapped in a list.
        offset: The offset that will be applied.

    Returns:
        A single-element list holding the simplified `index + offset` expression. The
        `expressions` argument is returned unaffected when `offset` is 0, when it does
        not contain exactly one expression, when `this` has a type other than ARRAY or
        UNKNOWN, or when the index expression is not of an integer type.
    """
    if not offset or len(expressions) != 1:
        return expressions

    expression = expressions[0]

    # Imported lazily: these modules live in the sqlglot package itself.
    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    # Infer types only when missing; annotate_types is relied on to populate `.type`
    # on the node in place (its return value is not used).
    if not this.type:
        annotate_types(this)

    # Only arrays (or targets whose type could not be inferred) are indexed positionally.
    if t.cast(exp.DataType, this.type).this not in (
        exp.DataType.Type.UNKNOWN,
        exp.DataType.Type.ARRAY,
    ):
        return expressions

    if not expression.type:
        annotate_types(expression)
    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
        logger.warning("Applying array index offset (%s)", offset)
        expression = simplify(exp.Add(this=expression, expression=exp.Literal.number(offset)))
        return [expression]

    return expressions

Applies an offset to a given integer-typed array index expression.

Arguments:
  • this: The target of the index.
  • expressions: The expression the offset will be applied to, wrapped in a list.
  • offset: The offset that will be applied.
Returns:

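The original expression with the offset applied to it, wrapped in a list. The expressions argument is returned unaffected if the offset is 0, if it does not contain exactly one expression, if this has a type other than ARRAY or UNKNOWN, or if the index expression is not of an integer type.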

def camel_to_snake_case(name: str) -> str:
def camel_to_snake_case(name: str) -> str:
    """
    Converts `name` from camelCase / CamelCase to UPPER_SNAKE_CASE and returns the result.

    An underscore is inserted before every uppercase letter except the first character,
    then the whole string is upper-cased, e.g. `"DateAdd"` -> `"DATE_ADD"`. Runs of
    capitals are split letter by letter (`"HTTPServer"` -> `"H_T_T_P_SERVER"`).
    """
    return CAMEL_CASE_PATTERN.sub("_", name).upper()

Converts name from camelCase (or CamelCase) to UPPER_SNAKE_CASE and returns the result, e.g. DateAdd becomes DATE_ADD.

def while_changing( expression: sqlglot.expressions.Expression, func: Callable[[sqlglot.expressions.Expression], ~E]) -> ~E:
def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Repeatedly applies `func` to `expression` until it stops changing.

    "Stops changing" means the hash of the tree returned by `func` equals the hash of
    the tree it was given.

    Args:
        expression: The root of the tree to be transformed.
        func: The transformation to be applied; it returns the (possibly new) root.

    Returns:
        The root produced by the final application of `func`.
    """
    while True:
        # Cache every node's hash in `_hash`, children before parents (hence `reversed`).
        # NOTE(review): assumes Expression.__hash__ reuses a cached `_hash` - confirm.
        for node, *_ in reversed(tuple(expression.walk())):
            node._hash = hash(node)

        before = hash(expression)
        expression = func(expression)

        # Drop the cached hashes so the comparison below sees the tree as it is now.
        for node, *_ in expression.walk():
            node._hash = None

        if hash(expression) == before:
            return expression

Applies a transformation to a given expression until a fix point is reached.

Arguments:
  • expression: The expression to be transformed.
  • func: The transformation to be applied.
Returns:

The transformed expression.

def tsort(dag: Dict[~T, Set[~T]]) -> List[~T]:
def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    Nodes are emitted level by level: a node appears only after all of its dependencies,
    and nodes that become ready together are ordered with `sorted`, so the output is
    deterministic. The input graph is left unmodified.

    Args:
        dag: Maps each node to the set of nodes it depends on. Dependencies that are not
            keys themselves are treated as nodes without dependencies.

    Returns:
        A list that contains all of the graph's nodes in topological order.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    # Work on a private copy: the loop below consumes the graph, and mutating the
    # caller's dict (and its dependency sets) would be a surprising side effect.
    graph = {node: set(deps) for node, deps in dag.items()}
    for deps in dag.values():
        for dep in deps:
            graph.setdefault(dep, set())

    result: t.List[T] = []
    while graph:
        ready = {node for node, deps in graph.items() if not deps}
        if not ready:
            raise ValueError("Cycle error")

        for node in ready:
            del graph[node]
        for deps in graph.values():
            deps -= ready

        result.extend(sorted(ready))  # type: ignore

    return result

Sorts a given directed acyclic graph in topological order.

Arguments:
  • dag: The graph to be sorted.
Returns:

A list that contains all of the graph's nodes in topological order.

def open_file(file_name: str) -> <class 'TextIO'>:
def open_file(file_name: str) -> t.TextIO:
    """
    Open a text file that may be gzip-compressed.

    Compression is detected from the gzip magic number (first two bytes `1f 8b`), not
    from the file extension. Both branches decode as UTF-8 and use `newline=""`
    (universal newlines, with line endings returned untranslated, as the `csv` module
    expects).

    Args:
        file_name: Path of the file to open.

    Returns:
        A text-mode file object; the caller is responsible for closing it.
    """
    with open(file_name, "rb") as f:
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        # Pass the encoding explicitly: without it gzip's text mode falls back to the
        # locale's preferred encoding, unlike the plain-file branch below.
        return gzip.open(file_name, "rt", encoding="utf-8", newline="")

    return open(file_name, encoding="utf-8", newline="")

Open a file that may be compressed as gzip and return it in universal newline mode.

@contextmanager
def csv_reader(read_csv: sqlglot.expressions.ReadCSV) -> Any:
@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    Args:
        read_csv: A `ReadCSV` function call. Its `expressions` are read as alternating
            option names and values; only `delimiter` is recognized (default `","`, the
            last occurrence wins), and a trailing name without a value is ignored.

    Yields:
        A python csv reader over the file opened by `open_file`; the file is closed when
        the context exits.
    """
    import csv as csv_  # aliased because this module defines its own `csv` function

    # Parse the options before opening the file so a failure here cannot leak the handle.
    names = [arg.name for arg in read_csv.expressions]
    options = dict(zip(names[::2], names[1::2]))
    delimiter = options.get("delimiter", ",")

    with open_file(read_csv.name) as file:
        yield csv_.reader(file, delimiter=delimiter)

Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]).

Arguments:
  • read_csv: A ReadCSV function call.
Yields:

A python csv reader.

def find_new_name(taken: Collection[str], base: str) -> str:
def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Pick a name derived from `base` that is not already taken.

    Args:
        taken: Names that are already in use.
        base: The preferred name.

    Returns:
        `base` if it is free, otherwise the first free name among `base_2`, `base_3`, ...
    """
    if base not in taken:
        return base

    suffix = 2
    while f"{base}_{suffix}" in taken:
        suffix += 1
    return f"{base}_{suffix}"

Searches for a new name.

Arguments:
  • taken: A collection of taken names.
  • base: Base name to alter.
Returns:

The new, available name.

def is_int(text: str) -> bool:
def is_int(text: str) -> bool:
    """
    Return whether `int(text)` succeeds.

    Surrounding whitespace, a leading sign and `_` digit separators are accepted, as
    they are by `int`.
    """
    try:
        int(text)
    except ValueError:
        return False
    return True
def name_sequence(prefix: str) -> Callable[[], str]:
def name_sequence(prefix: str) -> t.Callable[[], str]:
    """
    Build a generator of unique names.

    Each call of the returned function produces `prefix` followed by the next integer,
    starting from 0 (e.g. a0, a1, a2, ... if the prefix is "a").
    """
    counter = count()

    def next_name() -> str:
        return f"{prefix}{next(counter)}"

    return next_name

Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").

def object_to_dict(obj: Any, **kwargs) -> Dict:
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """
    Snapshot an object's instance attributes (`vars(obj)`) into a new dict.

    Each value is shallow-copied, using its own `copy()` method when it has one and
    `copy.copy` otherwise. Keyword arguments are merged in last and override attributes
    of the same name.
    """
    snapshot = {}
    for attr, value in vars(obj).items():
        snapshot[attr] = value.copy() if hasattr(value, "copy") else copy(value)
    snapshot.update(kwargs)
    return snapshot

Returns a dictionary created from an object's attributes.

def split_num_words( value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> List[Optional[str]]:
def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Split `value` on `sep`, padding the result with `None` up to `min_num_words` items.

    Args:
        value: The string to split.
        sep: The separator passed to `str.split`.
        min_num_words: The minimum length of the result; words are never dropped.
        fill_from_start: Whether the `None` padding goes at the start (True) or at the
            end (False) of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The words produced by `split`, plus any `None` padding.
    """
    words = value.split(sep)
    # A non-positive count yields an empty list, so nothing is added when there are
    # already enough words.
    padding = [None] * (min_num_words - len(words))
    return padding + words if fill_from_start else words + padding

Perform a split on a value and return N words as a result with None used for words that don't exist.

Arguments:
  • value: The value to be split.
  • sep: The value to use to split on.
  • min_num_words: The minimum number of words that are going to be in the result.
  • fill_from_start: Indicates whether None values should be inserted at the start (True) or at the end (False) of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
[None, 'db', 'table']
>>> split_num_words("db.table", ".", 3, fill_from_start=False)
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
Returns:

The list of words returned by split, possibly augmented by a number of None values.

def is_iterable(value: Any) -> bool:
def is_iterable(value: t.Any) -> bool:
    """
    Checks if the value is an iterable, excluding the types `str`, `bytes` and sqlglot's
    `Expression`.

    A value counts as iterable when it defines `__iter__`. Strings, bytes and
    `Expression` nodes define it too, but are treated as atomic values here.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check if it is an iterable.

    Returns:
        A `bool` value indicating if it is an iterable.
    """
    # Imported at call time rather than at module level.
    # NOTE(review): presumably to avoid an import cycle with sqlglot - confirm.
    from sqlglot import Expression

    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))

Checks if the value is an iterable, excluding the types str and bytes as well as sqlglot Expression objects.

Examples:
>>> is_iterable([1,2])
True
>>> is_iterable("test")
False
Arguments:
  • value: The value to check if it is an iterable.
Returns:

A bool value indicating if it is an iterable.

def flatten(values: Iterable[Union[Iterable[Any], Any]]) -> Iterator[Any]:
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Recursively flatten nested iterables into a single stream of leaf values.

    Whether an element is descended into is decided by `is_iterable`, so `str` and
    `bytes` objects are yielded whole.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The (possibly nested) iterable to flatten.

    Yields:
        The non-iterable elements of `values`, depth-first in iteration order.
    """
    for item in values:
        if not is_iterable(item):
            yield item
            continue
        yield from flatten(item)

Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type str and bytes are not regarded as iterables.

Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
[1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Arguments:
  • values: The value to be flattened.
Yields:

Non-iterable elements in values.

def dict_depth(d: Dict) -> int:
def dict_depth(d: t.Dict) -> int:
    """
    Get the nesting depth of a dictionary.

    Only the first value at each level is inspected; anything without a `values()`
    method has depth 0, and an empty mapping has depth 1.

    Example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3
    """
    try:
        first_value = next(iter(d.values()))
    except AttributeError:
        # Not a mapping: contributes no depth.
        return 0
    except StopIteration:
        # An empty mapping is one level deep.
        return 1
    return 1 + dict_depth(first_value)

Get the nesting depth of a dictionary.

Example:
>>> dict_depth(None)
0
>>> dict_depth({})
1
>>> dict_depth({"a": "b"})
1
>>> dict_depth({"a": {}})
2
>>> dict_depth({"a": {"b": {}}})
3
def first(it: Iterable[~T]) -> ~T:
def first(it: t.Iterable[T]) -> T:
    """
    Return the first item produced by iterating `it` (useful for sets, which can't be indexed).

    Raises:
        StopIteration: If `it` yields nothing.
    """
    return next(iter(it))

Returns the first element from an iterable (useful for sets).

def merge_ranges(ranges: List[Tuple[~A, ~A]]) -> List[Tuple[~A, ~A]]:
def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
    """
    Merge overlapping or touching `(low, high)` ranges over a totally-ordered type.

    The ranges are sorted first; a range is folded into its predecessor when it starts
    at or before the predecessor's end.

    Example:
        >>> merge_ranges([(1, 3), (2, 6)])
        [(1, 6)]
    """
    if not ranges:
        return []

    ordered = sorted(ranges)
    merged = ordered[:1]

    for low, high in ordered[1:]:
        prev_low, prev_high = merged[-1]
        if low <= prev_high:
            merged[-1] = (prev_low, max(prev_high, high))
        else:
            merged.append((low, high))

    return merged

Merges a sequence of ranges, represented as tuples (low, high) whose values belong to some totally-ordered set.

Example:
>>> merge_ranges([(1, 3), (2, 6)])
[(1, 6)]
def is_iso_date(text: str) -> bool:
def is_iso_date(text: str) -> bool:
    """Return whether `text` is accepted by `datetime.date.fromisoformat`."""
    try:
        datetime.date.fromisoformat(text)
    except ValueError:
        return False
    return True
def is_iso_datetime(text: str) -> bool:
def is_iso_datetime(text: str) -> bool:
    """
    Return whether `text` is accepted by `datetime.datetime.fromisoformat`.

    The accepted formats are those of the running Python version's `fromisoformat`.
    """
    try:
        datetime.datetime.fromisoformat(text)
    except ValueError:
        return False
    return True
DATE_UNITS = {'quarter', 'month', 'day', 'year', 'week', 'year_month'}
def is_date_unit(expression: Optional[sqlglot.expressions.Expression]) -> bool:
def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
    """
    Return whether `expression`'s name is one of `DATE_UNITS`, compared case-insensitively.

    `None` is accepted and yields False.
    """
    if expression is None:
        return False
    return expression.name.lower() in DATE_UNITS