diff options
Diffstat (limited to 'sqlglot/helper.py')
-rw-r--r-- | sqlglot/helper.py | 209 |
1 files changed, 165 insertions, 44 deletions
diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 42965d1..379c2e7 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -1,48 +1,125 @@ +from __future__ import annotations + import inspect import logging import re import sys import typing as t +from collections.abc import Collection from contextlib import contextmanager from copy import copy from enum import Enum +if t.TYPE_CHECKING: + from sqlglot.expressions import Expression, Table + + T = t.TypeVar("T") + E = t.TypeVar("E", bound=Expression) + CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])") +PYTHON_VERSION = sys.version_info[:2] logger = logging.getLogger("sqlglot") class AutoName(Enum): - def _generate_next_value_(name, _start, _count, _last_values): + """This is used for creating enum classes where `auto()` is the string form of the corresponding value's name.""" + + def _generate_next_value_(name, _start, _count, _last_values): # type: ignore return name -def list_get(arr, index): +def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]: + """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds.""" try: - return arr[index] + return seq[index] except IndexError: return None +@t.overload +def ensure_list(value: t.Collection[T]) -> t.List[T]: + ... + + +@t.overload +def ensure_list(value: T) -> t.List[T]: + ... + + def ensure_list(value): + """ + Ensures that a value is a list, otherwise casts or wraps it into one. + + Args: + value: the value of interest. + + Returns: + The value cast as a list if it's a list or a tuple, or else the value wrapped in a list. + """ if value is None: return [] - return value if isinstance(value, (list, tuple, set)) else [value] + elif isinstance(value, (list, tuple)): + return list(value) + + return [value] + + +@t.overload +def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: + ... -def csv(*args, sep=", "): +@t.overload +def ensure_collection(value: T) -> t.Collection[T]: + ... + + +def ensure_collection(value): + """ + Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list. + + Args: + value: the value of interest. + + Returns: + The value if it's a collection, or else the value wrapped in a list. + """ + if value is None: + return [] + return ( + value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value] + ) + + +def csv(*args, sep: str = ", ") -> str: + """ + Formats any number of string arguments as CSV. + + Args: + args: the string arguments to format. + sep: the argument separator. + + Returns: + The arguments formatted as a CSV string. + """ return sep.join(arg for arg in args if arg) -def subclasses(module_name, classes, exclude=()): +def subclasses( + module_name: str, + classes: t.Type | t.Tuple[t.Type, ...], + exclude: t.Type | t.Tuple[t.Type, ...] = (), +) -> t.List[t.Type]: """ - Returns a list of all subclasses for a specified class set, posibly excluding some of them. + Returns all subclasses for a collection of classes, possibly excluding some of them. Args: - module_name (str): The name of the module to search for subclasses in. - classes (type|tuple[type]): Class(es) we want to find the subclasses of. - exclude (type|tuple[type]): Class(es) we want to exclude from the returned list. + module_name: the name of the module to search for subclasses in. + classes: class(es) we want to find the subclasses of. + exclude: class(es) we want to exclude from the returned list. + Returns: - A list of all the target subclasses. + The target subclasses. """ return [ obj @@ -53,7 +130,18 @@ def subclasses(module_name, classes, exclude=()): ] -def apply_index_offset(expressions, offset): +def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]: + """ + Applies an offset to a given integer literal expression. + + Args: + expressions: the expression the offset will be applied to, wrapped in a list. + offset: the offset that will be applied. + + Returns: + The original expression with the offset applied to it, wrapped in a list. If the provided + `expressions` argument contains more than one expressions, it's returned unaffected. + """ if not offset or len(expressions) != 1: return expressions @@ -64,14 +152,28 @@ def apply_index_offset(expressions, offset): logger.warning("Applying array index offset (%s)", offset) expression.args["this"] = str(int(expression.args["this"]) + offset) return [expression] + return expressions -def camel_to_snake_case(name): +def camel_to_snake_case(name: str) -> str: + """Converts `name` from camelCase to snake_case and returns the result.""" return CAMEL_CASE_PATTERN.sub("_", name).upper() -def while_changing(expression, func): +def while_changing( + expression: t.Optional[Expression], func: t.Callable[[t.Optional[Expression]], E] +) -> E: + """ + Applies a transformation to a given expression until a fix point is reached. + + Args: + expression: the expression to be transformed. + func: the transformation to be applied. + + Returns: + The transformed expression. + """ while True: start = hash(expression) expression = func(expression) @@ -80,10 +182,19 @@ def while_changing(expression, func): return expression -def tsort(dag): +def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]: + """ + Sorts a given directed acyclic graph in topological order. + + Args: + dag: the graph to be sorted. + + Returns: + A list that contains all of the graph's nodes in topological order. + """ result = [] - def visit(node, visited): + def visit(node: T, visited: t.Set[T]) -> None: if node in result: return if node in visited: @@ -103,10 +214,8 @@ def tsort(dag): return result -def open_file(file_name): - """ - Open a file that may be compressed as gzip and return in newline mode. - """ +def open_file(file_name: str) -> t.TextIO: + """Open a file that may be compressed as gzip and return it in universal newline mode.""" with open(file_name, "rb") as f: gzipped = f.read(2) == b"\x1f\x8b" @@ -119,14 +228,14 @@ def open_file(file_name): @contextmanager -def csv_reader(table): +def csv_reader(table: Table) -> t.Any: """ - Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]) + Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`. Args: - table (exp.Table): A table expression with an anonymous function READ_CSV in it + table: a `Table` expression with an anonymous function `READ_CSV` in it. - Returns: + Yields: A python csv reader. """ file, *args = table.this.expressions @@ -147,13 +256,16 @@ def csv_reader(table): file.close() -def find_new_name(taken, base): +def find_new_name(taken: t.Sequence[str], base: str) -> str: """ Searches for a new name. Args: - taken (Sequence[str]): set of taken names - base (str): base name to alter + taken: a collection of taken names. + base: base name to alter. + + Returns: + The new, available name. """ if base not in taken: return base @@ -163,22 +275,26 @@ def find_new_name(taken, base): while new in taken: i += 1 new = f"{base}_{i}" + return new -def object_to_dict(obj, **kwargs): +def object_to_dict(obj: t.Any, **kwargs) -> t.Dict: + """Returns a dictionary created from an object's attributes.""" return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs} -def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> t.List[t.Optional[str]]: +def split_num_words( + value: str, sep: str, min_num_words: int, fill_from_start: bool = True +) -> t.List[t.Optional[str]]: """ - Perform a split on a value and return N words as a result with None used for words that don't exist. + Perform a split on a value and return N words as a result with `None` used for words that don't exist. Args: - value: The value to be split - sep: The value to use to split on - min_num_words: The minimum number of words that are going to be in the result - fill_from_start: Indicates that if None values should be inserted at the start or end of the list + value: the value to be split. + sep: the value to use to split on. + min_num_words: the minimum number of words that are going to be in the result. + fill_from_start: indicates that if `None` values should be inserted at the start or end of the list. Examples: >>> split_num_words("db.table", ".", 3) @@ -187,6 +303,9 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b ['db', 'table', None] >>> split_num_words("db.table", ".", 1) ['db', 'table'] + + Returns: + The list of words returned by `split`, possibly augmented by a number of `None` values. """ words = value.split(sep) if fill_from_start: @@ -196,7 +315,7 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b def is_iterable(value: t.Any) -> bool: """ - Checks if the value is an iterable but does not include strings and bytes + Checks if the value is an iterable, excluding the types `str` and `bytes`. Examples: >>> is_iterable([1,2]) @@ -205,28 +324,30 @@ def is_iterable(value: t.Any) -> bool: False Args: - value: The value to check if it is an interable + value: the value to check if it is an iterable. - Returns: Bool indicating if it is an iterable + Returns: + A `bool` value indicating if it is an iterable. """ return hasattr(value, "__iter__") and not isinstance(value, (str, bytes)) -def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generator[t.Any, None, None]: +def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any, None, None]: """ - Flattens a list that can contain both iterables and non-iterable elements + Flattens an iterable that can contain both iterable and non-iterable elements. Objects of + type `str` and `bytes` are not regarded as iterables. Examples: - >>> list(flatten([[1, 2], 3])) - [1, 2, 3] + >>> list(flatten([[1, 2], 3, {4}, (5, "bla")])) + [1, 2, 3, 4, 5, 'bla'] >>> list(flatten([1, 2, 3])) [1, 2, 3] Args: - values: The value to be flattened + values: the value to be flattened. - Returns: - Yields non-iterable elements (not including str or byte as iterable) + Yields: + Non-iterable elements in `values`. """ for value in values: if is_iterable(value): |