summaryrefslogtreecommitdiffstats
path: root/sqlglot/helper.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/helper.py')
-rw-r--r--sqlglot/helper.py209
1 files changed, 165 insertions, 44 deletions
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 42965d1..379c2e7 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -1,48 +1,125 @@
+from __future__ import annotations
+
import inspect
import logging
import re
import sys
import typing as t
+from collections.abc import Collection
from contextlib import contextmanager
from copy import copy
from enum import Enum
+if t.TYPE_CHECKING:
+ from sqlglot.expressions import Expression, Table
+
+ T = t.TypeVar("T")
+ E = t.TypeVar("E", bound=Expression)
+
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
+PYTHON_VERSION = sys.version_info[:2]
logger = logging.getLogger("sqlglot")
class AutoName(Enum):
- def _generate_next_value_(name, _start, _count, _last_values):
+ """This is used for creating enum classes where `auto()` is the string form of the corresponding value's name."""
+
+ def _generate_next_value_(name, _start, _count, _last_values): # type: ignore
return name
-def list_get(arr, index):
+def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
+ """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
try:
- return arr[index]
+ return seq[index]
except IndexError:
return None
+@t.overload
+def ensure_list(value: t.Collection[T]) -> t.List[T]:
+ ...
+
+
+@t.overload
+def ensure_list(value: T) -> t.List[T]:
+ ...
+
+
def ensure_list(value):
+ """
+ Ensures that a value is a list, otherwise casts or wraps it into one.
+
+ Args:
+ value: the value of interest.
+
+ Returns:
+ The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
+ """
if value is None:
return []
- return value if isinstance(value, (list, tuple, set)) else [value]
+ elif isinstance(value, (list, tuple)):
+ return list(value)
+
+ return [value]
+
+
+@t.overload
+def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
+ ...
-def csv(*args, sep=", "):
+@t.overload
+def ensure_collection(value: T) -> t.Collection[T]:
+ ...
+
+
+def ensure_collection(value):
+ """
+ Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
+
+ Args:
+ value: the value of interest.
+
+ Returns:
+ The value if it's a collection, or else the value wrapped in a list.
+ """
+ if value is None:
+ return []
+ return (
+ value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
+ )
+
+
+def csv(*args, sep: str = ", ") -> str:
+ """
+ Formats any number of string arguments as CSV.
+
+ Args:
+ args: the string arguments to format.
+ sep: the argument separator.
+
+ Returns:
+ The arguments formatted as a CSV string.
+ """
return sep.join(arg for arg in args if arg)
-def subclasses(module_name, classes, exclude=()):
+def subclasses(
+ module_name: str,
+ classes: t.Type | t.Tuple[t.Type, ...],
+ exclude: t.Type | t.Tuple[t.Type, ...] = (),
+) -> t.List[t.Type]:
"""
- Returns a list of all subclasses for a specified class set, posibly excluding some of them.
+ Returns all subclasses for a collection of classes, possibly excluding some of them.
Args:
- module_name (str): The name of the module to search for subclasses in.
- classes (type|tuple[type]): Class(es) we want to find the subclasses of.
- exclude (type|tuple[type]): Class(es) we want to exclude from the returned list.
+ module_name: the name of the module to search for subclasses in.
+ classes: class(es) we want to find the subclasses of.
+ exclude: class(es) we want to exclude from the returned list.
+
Returns:
- A list of all the target subclasses.
+ The target subclasses.
"""
return [
obj
@@ -53,7 +130,18 @@ def subclasses(module_name, classes, exclude=()):
]
-def apply_index_offset(expressions, offset):
+def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
+ """
+ Applies an offset to a given integer literal expression.
+
+ Args:
+ expressions: the expression the offset will be applied to, wrapped in a list.
+ offset: the offset that will be applied.
+
+ Returns:
+ The original expression with the offset applied to it, wrapped in a list. If the provided
+ `expressions` argument contains more than one expressions, it's returned unaffected.
+ """
if not offset or len(expressions) != 1:
return expressions
@@ -64,14 +152,28 @@ def apply_index_offset(expressions, offset):
logger.warning("Applying array index offset (%s)", offset)
expression.args["this"] = str(int(expression.args["this"]) + offset)
return [expression]
+
return expressions
-def camel_to_snake_case(name):
+def camel_to_snake_case(name: str) -> str:
+ """Converts `name` from camelCase to snake_case and returns the result."""
return CAMEL_CASE_PATTERN.sub("_", name).upper()
-def while_changing(expression, func):
+def while_changing(
+ expression: t.Optional[Expression], func: t.Callable[[t.Optional[Expression]], E]
+) -> E:
+ """
+ Applies a transformation to a given expression until a fix point is reached.
+
+ Args:
+ expression: the expression to be transformed.
+ func: the transformation to be applied.
+
+ Returns:
+ The transformed expression.
+ """
while True:
start = hash(expression)
expression = func(expression)
@@ -80,10 +182,19 @@ def while_changing(expression, func):
return expression
-def tsort(dag):
+def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
+ """
+ Sorts a given directed acyclic graph in topological order.
+
+ Args:
+ dag: the graph to be sorted.
+
+ Returns:
+ A list that contains all of the graph's nodes in topological order.
+ """
result = []
- def visit(node, visited):
+ def visit(node: T, visited: t.Set[T]) -> None:
if node in result:
return
if node in visited:
@@ -103,10 +214,8 @@ def tsort(dag):
return result
-def open_file(file_name):
- """
- Open a file that may be compressed as gzip and return in newline mode.
- """
+def open_file(file_name: str) -> t.TextIO:
+ """Open a file that may be compressed as gzip and return it in universal newline mode."""
with open(file_name, "rb") as f:
gzipped = f.read(2) == b"\x1f\x8b"
@@ -119,14 +228,14 @@ def open_file(file_name):
@contextmanager
-def csv_reader(table):
+def csv_reader(table: Table) -> t.Any:
"""
- Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...])
+ Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
Args:
- table (exp.Table): A table expression with an anonymous function READ_CSV in it
+ table: a `Table` expression with an anonymous function `READ_CSV` in it.
- Returns:
+ Yields:
A python csv reader.
"""
file, *args = table.this.expressions
@@ -147,13 +256,16 @@ def csv_reader(table):
file.close()
-def find_new_name(taken, base):
+def find_new_name(taken: t.Sequence[str], base: str) -> str:
"""
Searches for a new name.
Args:
- taken (Sequence[str]): set of taken names
- base (str): base name to alter
+ taken: a collection of taken names.
+ base: base name to alter.
+
+ Returns:
+ The new, available name.
"""
if base not in taken:
return base
@@ -163,22 +275,26 @@ def find_new_name(taken, base):
while new in taken:
i += 1
new = f"{base}_{i}"
+
return new
-def object_to_dict(obj, **kwargs):
+def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
+ """Returns a dictionary created from an object's attributes."""
return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs}
-def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> t.List[t.Optional[str]]:
+def split_num_words(
+ value: str, sep: str, min_num_words: int, fill_from_start: bool = True
+) -> t.List[t.Optional[str]]:
"""
- Perform a split on a value and return N words as a result with None used for words that don't exist.
+ Perform a split on a value and return N words as a result with `None` used for words that don't exist.
Args:
- value: The value to be split
- sep: The value to use to split on
- min_num_words: The minimum number of words that are going to be in the result
- fill_from_start: Indicates that if None values should be inserted at the start or end of the list
+ value: the value to be split.
+ sep: the value to use to split on.
+ min_num_words: the minimum number of words that are going to be in the result.
+ fill_from_start: indicates that if `None` values should be inserted at the start or end of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
@@ -187,6 +303,9 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
+
+ Returns:
+ The list of words returned by `split`, possibly augmented by a number of `None` values.
"""
words = value.split(sep)
if fill_from_start:
@@ -196,7 +315,7 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b
def is_iterable(value: t.Any) -> bool:
"""
- Checks if the value is an iterable but does not include strings and bytes
+ Checks if the value is an iterable, excluding the types `str` and `bytes`.
Examples:
>>> is_iterable([1,2])
@@ -205,28 +324,30 @@ def is_iterable(value: t.Any) -> bool:
False
Args:
- value: The value to check if it is an interable
+ value: the value to check if it is an iterable.
- Returns: Bool indicating if it is an iterable
+ Returns:
+ A `bool` value indicating if it is an iterable.
"""
return hasattr(value, "__iter__") and not isinstance(value, (str, bytes))
-def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generator[t.Any, None, None]:
+def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any, None, None]:
"""
- Flattens a list that can contain both iterables and non-iterable elements
+ Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
+ type `str` and `bytes` are not regarded as iterables.
Examples:
- >>> list(flatten([[1, 2], 3]))
- [1, 2, 3]
+ >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
+ [1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Args:
- values: The value to be flattened
+ values: the value to be flattened.
- Returns:
- Yields non-iterable elements (not including str or byte as iterable)
+ Yields:
+ Non-iterable elements in `values`.
"""
for value in values:
if is_iterable(value):