From d1f00706bff58b863b0a1c5bf4adf39d36049d4c Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Fri, 11 Nov 2022 09:54:35 +0100
Subject: Merging upstream version 10.0.1.

Signed-off-by: Daniel Baumann
---
 sqlglot/__init__.py                       |  71 +++---
 sqlglot/__main__.py                       |   5 +-
 sqlglot/dataframe/sql/_typing.pyi         |  14 +-
 sqlglot/dataframe/sql/column.py           |  46 +++-
 sqlglot/dataframe/sql/dataframe.py        | 158 +++++++++---
 sqlglot/dataframe/sql/functions.py        | 100 ++++++--
 sqlglot/dataframe/sql/group.py            |  10 +-
 sqlglot/dataframe/sql/normalize.py        |  13 +-
 sqlglot/dataframe/sql/readwriter.py       |  16 +-
 sqlglot/dataframe/sql/session.py          |  17 +-
 sqlglot/dataframe/sql/types.py            |   6 +-
 sqlglot/dataframe/sql/window.py           |  27 +-
 sqlglot/dialects/bigquery.py              |  57 +++--
 sqlglot/dialects/clickhouse.py            |  24 +-
 sqlglot/dialects/databricks.py            |   4 +-
 sqlglot/dialects/dialect.py               |  52 ++--
 sqlglot/dialects/duckdb.py                |  33 +--
 sqlglot/dialects/hive.py                  |  57 +++--
 sqlglot/dialects/mysql.py                 | 329 ++++++++++++++++++++++--
 sqlglot/dialects/oracle.py                |  20 +-
 sqlglot/dialects/postgres.py              |  25 +-
 sqlglot/dialects/presto.py                |  41 ++-
 sqlglot/dialects/redshift.py              |  13 +-
 sqlglot/dialects/snowflake.py             |  46 ++--
 sqlglot/dialects/spark.py                 |  37 +--
 sqlglot/dialects/sqlite.py                |  24 +-
 sqlglot/dialects/starrocks.py             |   7 +-
 sqlglot/dialects/tableau.py               |  14 +-
 sqlglot/dialects/trino.py                 |   4 +-
 sqlglot/dialects/tsql.py                  |  54 ++--
 sqlglot/diff.py                           |  23 +-
 sqlglot/errors.py                         |   9 +-
 sqlglot/executor/context.py               |  44 +++-
 sqlglot/executor/env.py                   |   4 +-
 sqlglot/executor/python.py                | 190 +++++++-------
 sqlglot/executor/table.py                 |  27 +-
 sqlglot/expressions.py                    | 258 ++++++++++++-------
 sqlglot/generator.py                      | 214 +++++++++++-----
 sqlglot/helper.py                         | 209 +++++++++++----
 sqlglot/optimizer/annotate_types.py       | 131 ++++++++--
 sqlglot/optimizer/eliminate_joins.py      |   4 +-
 sqlglot/optimizer/eliminate_subqueries.py |  12 +-
 sqlglot/optimizer/merge_subqueries.py     |  16 +-
 sqlglot/optimizer/normalize.py            |   4 +-
 sqlglot/optimizer/optimize_joins.py       |   6 +-
 sqlglot/optimizer/optimizer.py            |   4 +-
 sqlglot/optimizer/pushdown_predicates.py  |  28 +-
 sqlglot/optimizer/pushdown_projections.py |   4 +-
 sqlglot/optimizer/qualify_columns.py      |  28 +-
 sqlglot/optimizer/scope.py                |  14 +-
 sqlglot/optimizer/simplify.py             |  12 +-
 sqlglot/optimizer/unnest_subqueries.py    |  14 +-
 sqlglot/parser.py                         | 410 +++++++++++++++++++++---------
 sqlglot/planner.py                        |  19 +-
 sqlglot/py.typed                          |   0
 sqlglot/schema.py                         | 298 ++++++++++++++--------
 sqlglot/time.py                           |  17 +-
 sqlglot/tokens.py                         | 247 +++++++++++-------
 sqlglot/transforms.py                     |  42 +--
 sqlglot/trie.py                           |  48 +++-
 60 files changed, 2549 insertions(+), 1111 deletions(-)
 create mode 100644 sqlglot/py.typed

diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index d6e18fd..6e67b19 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -1,5 +1,9 @@
 """## Python SQL parser, transpiler and optimizer."""

+from __future__ import annotations
+
+import typing as t
+
 from sqlglot import expressions as exp
 from sqlglot.dialects import Dialect, Dialects
 from sqlglot.diff import diff
@@ -20,51 +24,54 @@ from sqlglot.expressions import (
     subquery,
 )
 from sqlglot.expressions import table_ as table
-from sqlglot.expressions import union
+from sqlglot.expressions import to_column, to_table, union
 from sqlglot.generator import Generator
 from sqlglot.parser import Parser
 from sqlglot.schema import MappingSchema
 from sqlglot.tokens import Tokenizer, TokenType

-__version__ = "9.0.6"
+__version__ = "10.0.1"

 pretty = False

 schema = MappingSchema()


-def parse(sql, read=None, **opts):
+def parse(
+    sql: str, read: t.Optional[str | Dialect] = None, **opts
+) -> t.List[t.Optional[Expression]]:
     """
-    Parses the given SQL string into a collection of syntax trees, one per
-    parsed SQL statement.
+    Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement.

     Args:
-        sql (str): the SQL code string to parse.
-        read (str): the SQL dialect to apply during parsing
-            (eg. "spark", "hive", "presto", "mysql").
+        sql: the SQL code string to parse.
+        read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
         **opts: other options.

     Returns:
-        typing.List[Expression]: the list of parsed syntax trees.
+        The resulting syntax tree collection.
     """
     dialect = Dialect.get_or_raise(read)()
     return dialect.parse(sql, **opts)


-def parse_one(sql, read=None, into=None, **opts):
+def parse_one(
+    sql: str,
+    read: t.Optional[str | Dialect] = None,
+    into: t.Optional[Expression | str] = None,
+    **opts,
+) -> t.Optional[Expression]:
     """
-    Parses the given SQL string and returns a syntax tree for the first
-    parsed SQL statement.
+    Parses the given SQL string and returns a syntax tree for the first parsed SQL statement.

     Args:
-        sql (str): the SQL code string to parse.
-        read (str): the SQL dialect to apply during parsing
-            (eg. "spark", "hive", "presto", "mysql").
-        into (Expression): the SQLGlot Expression to parse into
+        sql: the SQL code string to parse.
+        read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
+        into: the SQLGlot Expression to parse into.
         **opts: other options.

     Returns:
-        Expression: the syntax tree for the first parsed statement.
+        The syntax tree for the first parsed statement.
     """
     dialect = Dialect.get_or_raise(read)()
@@ -77,25 +84,29 @@ def parse_one(sql, read=None, into=None, **opts):
     return result[0] if result else None


-def transpile(sql, read=None, write=None, identity=True, error_level=None, **opts):
+def transpile(
+    sql: str,
+    read: t.Optional[str | Dialect] = None,
+    write: t.Optional[str | Dialect] = None,
+    identity: bool = True,
+    error_level: t.Optional[ErrorLevel] = None,
+    **opts,
+) -> t.List[str]:
     """
-    Parses the given SQL string using the source dialect and returns a list of SQL strings
-    transformed to conform to the target dialect. Each string in the returned list represents
-    a single transformed SQL statement.
+    Parses the given SQL string in accordance with the source dialect and returns a list of SQL strings transformed
+    to conform to the target dialect. Each string in the returned list represents a single transformed SQL statement.

     Args:
-        sql (str): the SQL code string to transpile.
-        read (str): the source dialect used to parse the input string
-            (eg. "spark", "hive", "presto", "mysql").
-        write (str): the target dialect into which the input should be transformed
-            (eg. "spark", "hive", "presto", "mysql").
-        identity (bool): if set to True and if the target dialect is not specified
-            the source dialect will be used as both: the source and the target dialect.
-        error_level (ErrorLevel): the desired error level of the parser.
+        sql: the SQL code string to transpile.
+        read: the source dialect used to parse the input string (eg. "spark", "hive", "presto", "mysql").
+        write: the target dialect into which the input should be transformed (eg. "spark", "hive", "presto", "mysql").
+        identity: if set to `True` and if the target dialect is not specified the source dialect will be used as both:
+            the source and the target dialect.
+        error_level: the desired error level of the parser.
         **opts: other options.

     Returns:
-        typing.List[str]: the list of transpiled SQL statements / expressions.
+        The list of transpiled SQL statements.
     """
     write = write or read if identity else write
     return [
diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py
index c0fa380..42a54bc 100644
--- a/sqlglot/__main__.py
+++ b/sqlglot/__main__.py
@@ -49,7 +49,10 @@ args = parser.parse_args()
 error_level = sqlglot.ErrorLevel[args.error_level.upper()]

 if args.parse:
-    sqls = [repr(expression) for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)]
+    sqls = [
+        repr(expression)
+        for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)
+    ]
 else:
     sqls = sqlglot.transpile(
         args.sql,
diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.pyi
index f1a03ea..67c8c09 100644
--- a/sqlglot/dataframe/sql/_typing.pyi
+++ b/sqlglot/dataframe/sql/_typing.pyi
@@ -10,11 +10,17 @@ if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql.types import StructType

 ColumnLiterals = t.TypeVar(
-    "ColumnLiterals", bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+    "ColumnLiterals",
+    bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
 )
 ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
 ColumnOrLiteral = t.TypeVar(
-    "ColumnOrLiteral", bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+    "ColumnOrLiteral",
+    bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
+)
+SchemaInput = t.TypeVar(
+    "SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]]
+)
+OutputExpressionContainer = t.TypeVar(
+    "OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert]
 )
-SchemaInput = t.TypeVar("SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]])
-OutputExpressionContainer = t.TypeVar("OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert])
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index e66aaa8..f9e1c5b 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -18,7 +18,11 @@ class Column:
             expression = expression.expression  # type: ignore
         elif expression is None or not isinstance(expression, (str, exp.Expression)):
             expression = self._lit(expression).expression  # type: ignore
-        self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark")
+
+        expression = sqlglot.maybe_parse(expression, dialect="spark")
+        if expression is None:
+            raise ValueError(f"Could not parse {expression}")
+        self.expression: exp.Expression = expression

     def __repr__(self):
         return repr(self.expression)
@@ -135,21 +139,29 @@ class Column:
     ) -> Column:
         ensured_column = None if column is None else cls.ensure_col(column)
         ensure_expression_values = {
-            k: [Column.ensure_col(x).expression for x in v] if is_iterable(v) else Column.ensure_col(v).expression
+            k: [Column.ensure_col(x).expression for x in v]
+            if is_iterable(v)
+            else Column.ensure_col(v).expression
             for k, v in kwargs.items()
         }
         new_expression = (
             callable_expression(**ensure_expression_values)
             if ensured_column is None
-            else callable_expression(this=ensured_column.column_expression, **ensure_expression_values)
+            else callable_expression(
+                this=ensured_column.column_expression, **ensure_expression_values
+            )
         )
         return Column(new_expression)

     def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
-        return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs))
+        return Column(
+            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
+        )

     def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
-        return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs))
+        return Column(
+            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
+        )

     def unary_op(self, klass: t.Callable, **kwargs) -> Column:
         return Column(klass(this=self.column_expression, **kwargs))
@@ -188,7 +200,7 @@ class Column:
             expression.set("table", exp.to_identifier(table_name))
         return Column(expression)

-    def sql(self, **kwargs) -> Column:
+    def sql(self, **kwargs) -> str:
         return self.expression.sql(**{"dialect": "spark", **kwargs})

     def alias(self, name: str) -> Column:
@@ -265,10 +277,14 @@ class Column:
         )

     def like(self, other: str):
-        return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression)
+        return self.invoke_expression_over_column(
+            self, exp.Like, expression=self._lit(other).expression
+        )

     def ilike(self, other: str):
-        return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression)
+        return self.invoke_expression_over_column(
+            self, exp.ILike, expression=self._lit(other).expression
+        )

     def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
         startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
@@ -287,10 +303,18 @@ class Column:
         lowerBound: t.Union[ColumnOrLiteral],
         upperBound: t.Union[ColumnOrLiteral],
     ) -> Column:
-        lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
-        upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+        lower_bound_exp = (
+            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
+        )
+        upper_bound_exp = (
+            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+        )
         return Column(
-            exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression)
+            exp.Between(
+                this=self.column_expression,
+                low=lower_bound_exp.expression,
+                high=upper_bound_exp.expression,
+            )
         )

     def over(self, window: WindowSpec) -> Column:
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 322dcf2..40cd6c9 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -21,7 +21,12 @@ from sqlglot.optimizer import optimize as optimize_func
 from sqlglot.optimizer.qualify_columns import qualify_columns

 if t.TYPE_CHECKING:
-    from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer
+    from sqlglot.dataframe.sql._typing import (
+        ColumnLiterals,
+        ColumnOrLiteral,
+        ColumnOrName,
+        OutputExpressionContainer,
+    )
     from sqlglot.dataframe.sql.session import SparkSession
@@ -83,7 +88,9 @@ class DataFrame:
                 return from_exp.alias_or_name
             table_alias = from_exp.find(exp.TableAlias)
             if not table_alias:
-                raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}")
+                raise RuntimeError(
+                    f"Could not find an alias name for this expression: {self.expression}"
+                )
             return table_alias.alias_or_name
         return self.expression.ctes[-1].alias
@@ -132,12 +139,16 @@ class DataFrame:
         cte.set("sequence_id", sequence_id or self.sequence_id)
         return cte, name

-    def _ensure_list_of_columns(
-        self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]]
-    ) -> t.List[Column]:
-        columns = ensure_list(cols)
-        columns = Column.ensure_cols(columns)
-        return columns
+    @t.overload
+    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
+        ...
+
+    @t.overload
+    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
+        ...
+
+    def _ensure_list_of_columns(self, cols):
+        return Column.ensure_cols(ensure_list(cols))

     def _ensure_and_normalize_cols(self, cols):
         cols = self._ensure_list_of_columns(cols)
@@ -153,10 +164,16 @@ class DataFrame:
         df = self._resolve_pending_hints()
         sequence_id = sequence_id or df.sequence_id
         expression = df.expression.copy()
-        cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id)
-        new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression])
+        cte_expression, cte_name = df._create_cte_from_expression(
+            expression=expression, sequence_id=sequence_id
+        )
+        new_expression = df._add_ctes_to_expression(
+            exp.Select(), expression.ctes + [cte_expression]
+        )
         sel_columns = df._get_outer_select_columns(cte_expression)
-        new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns])
+        new_expression = new_expression.from_(cte_name).select(
+            *[x.alias_or_name for x in sel_columns]
+        )
         return df.copy(expression=new_expression, sequence_id=sequence_id)

     def _resolve_pending_hints(self) -> DataFrame:
@@ -169,16 +186,23 @@ class DataFrame:
             hint_expression.args.get("expressions").append(hint)
             df.pending_hints.remove(hint)

-        join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)}
+        join_aliases = {
+            join_table.alias_or_name
+            for join_table in get_tables_from_expression_with_join(expression)
+        }
         if join_aliases:
             for hint in df.pending_join_hints:
                 for sequence_id_expression in hint.expressions:
                     sequence_id_or_name = sequence_id_expression.alias_or_name
                     sequence_ids_to_match = [sequence_id_or_name]
                     if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
-                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name]
+                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
+                            sequence_id_or_name
+                        ]
                     matching_ctes = [
-                        cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match
+                        cte
+                        for cte in reversed(expression.ctes)
+                        if cte.args["sequence_id"] in sequence_ids_to_match
                     ]
                     for matching_cte in matching_ctes:
                         if matching_cte.alias_or_name in join_aliases:
@@ -193,9 +217,14 @@ class DataFrame:
     def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
         hint_name = hint_name.upper()
         hint_expression = (
-            exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args])
+            exp.JoinHint(
+                this=hint_name,
+                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
+            )
             if hint_name in JOIN_HINTS
-            else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args])
+            else exp.Anonymous(
+                this=hint_name, expressions=[parameter.expression for parameter in args]
+            )
         )
         new_df = self.copy()
         new_df.pending_hints.append(hint_expression)
@@ -245,7 +274,9 @@ class DataFrame:
     def _get_select_expressions(
         self,
     ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
-        select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = []
+        select_expressions: t.List[
+            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
+        ] = []
         main_select_ctes: t.List[exp.CTE] = []
         for cte in self.expression.ctes:
             cache_storage_level = cte.args.get("cache_storage_level")
@@ -279,14 +310,19 @@ class DataFrame:
                 cache_table_name = df._create_hash_from_expression(select_expression)
                 cache_table = exp.to_table(cache_table_name)
                 original_alias_name = select_expression.args["cte_alias_name"]
-                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name)
+
+                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
+                    cache_table_name
+                )
                 sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
                 cache_storage_level = select_expression.args["cache_storage_level"]
                 options = [
                     exp.Literal.string("storageLevel"),
                     exp.Literal.string(cache_storage_level),
                 ]
-                expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)
+                expression = exp.Cache(
+                    this=cache_table, expression=select_expression, lazy=True, options=options
+                )
                 # We will drop the "view" if it exists before running the cache table
                 output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
             elif expression_type == exp.Create:
@@ -305,7 +341,9 @@ class DataFrame:
                 raise ValueError(f"Invalid expression type: {expression_type}")
             output_expressions.append(expression)

-        return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions]
+        return [
+            expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
+        ]

     def copy(self, **kwargs) -> DataFrame:
         return DataFrame(**object_to_dict(self, **kwargs))
@@ -317,7 +355,9 @@ class DataFrame:
         if self.expression.args.get("joins"):
             ambiguous_cols = [col for col in cols if not col.column_expression.table]
             if ambiguous_cols:
-                join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)]
+                join_table_identifiers = [
+                    x.this for x in get_tables_from_expression_with_join(self.expression)
+                ]
                 cte_names_in_join = [x.this for x in join_table_identifiers]
                 for ambiguous_col in ambiguous_cols:
                     ctes_with_column = [
@@ -367,14 +407,20 @@ class DataFrame:

     @operation(Operation.FROM)
     def join(
-        self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs
+        self,
+        other_df: DataFrame,
+        on: t.Union[str, t.List[str], Column, t.List[Column]],
+        how: str = "inner",
+        **kwargs,
     ) -> DataFrame:
         other_df = other_df._convert_leaf_to_cte()
         pre_join_self_latest_cte_name = self.latest_cte_name
         columns = self._ensure_and_normalize_cols(on)
         join_type = how.replace("_", " ")
         if isinstance(columns[0].expression, exp.Column):
-            join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns]
+            join_columns = [
+                Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
+            ]
             join_clause = functools.reduce(
                 lambda x, y: x & y,
                 [
@@ -402,7 +448,9 @@ class DataFrame:
                 for column in self._get_outer_select_columns(other_df)
             ]
             column_value_mapping = {
-                column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column
+                column.alias_or_name
+                if not isinstance(column.expression.this, exp.Star)
+                else column.sql(): column
                 for column in other_columns + self_columns + join_columns
             }
             all_columns = [
@@ -410,16 +458,22 @@ class DataFrame:
                 for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
             ]
             new_df = self.copy(
-                expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type)
+                expression=self.expression.join(
+                    other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
+                )
+            )
+            new_df.expression = new_df._add_ctes_to_expression(
+                new_df.expression, other_df.expression.ctes
             )
-            new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes)
             new_df.pending_hints.extend(other_df.pending_hints)
             new_df = new_df.select.__wrapped__(new_df, *all_columns)
         return new_df

     @operation(Operation.ORDER_BY)
     def orderBy(
-        self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None
+        self,
+        *cols: t.Union[str, Column],
+        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
     ) -> DataFrame:
         """
         This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
@@ -429,7 +483,10 @@ class DataFrame:
         columns = self._ensure_and_normalize_cols(cols)
         pre_ordered_col_indexes = [
             x
-            for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)]
+            for x in [
+                i if isinstance(col.expression, exp.Ordered) else None
+                for i, col in enumerate(columns)
+            ]
             if x is not None
         ]
         if ascending is None:
@@ -478,7 +535,9 @@ class DataFrame:
             for r_column in r_columns_unused:
                 l_expressions.append(exp.alias_(exp.Null(), r_column))
                 r_expressions.append(r_column)
-        r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+        r_df = (
+            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+        )
         l_df = self.copy()
         if allowMissingColumns:
             l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
@@ -536,7 +595,9 @@ class DataFrame:
                 f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                 f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
             )
-        if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns]
+        if_null_checks = [
+            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
+        ]
         nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
         num_nulls = nulls_added_together.alias("num_nulls")
         new_df = new_df.select(num_nulls, append=True)
@@ -576,11 +637,15 @@ class DataFrame:
         value_columns = [lit(value) for value in values]

         null_replacement_mapping = {
-            column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name))
+            column.alias_or_name: (
+                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
+            )
             for column, value in zip(columns, value_columns)
         }
         null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
-        null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns]
+        null_replacement_columns = [
+            null_replacement_mapping[column.alias_or_name] for column in all_columns
+        ]
         new_df = new_df.select(*null_replacement_columns)
         return new_df
@@ -589,12 +654,11 @@ class DataFrame:
         self,
         to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
         value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
-        subset: t.Optional[t.Union[str, t.List[str]]] = None,
+        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
     ) -> DataFrame:
         from sqlglot.dataframe.sql.functions import lit

         old_values = None
-        subset = ensure_list(subset)
         new_df = self.copy()
         all_columns = self._get_outer_select_columns(new_df.expression)
         all_column_mapping = {column.alias_or_name: column for column in all_columns}
@@ -605,7 +669,9 @@ class DataFrame:
             new_values = list(to_replace.values())
         elif not old_values and isinstance(to_replace, list):
             assert isinstance(value, list), "value must be a list since the replacements are a list"
-            assert len(to_replace) == len(value), "the replacements and values must be the same length"
+            assert len(to_replace) == len(
+                value
+            ), "the replacements and values must be the same length"
             old_values = to_replace
             new_values = value
         else:
@@ -635,7 +701,9 @@ class DataFrame:
     def withColumn(self, colName: str, col: Column) -> DataFrame:
         col = self._ensure_and_normalize_col(col)
         existing_col_names = self.expression.named_selects
-        existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None
+        existing_col_index = (
+            existing_col_names.index(colName) if colName in existing_col_names else None
+        )
         if existing_col_index:
             expression = self.expression.copy()
             expression.expressions[existing_col_index] = col.expression
@@ -645,7 +713,11 @@ class DataFrame:
     @operation(Operation.SELECT)
     def withColumnRenamed(self, existing: str, new: str):
         expression = self.expression.copy()
-        existing_columns = [expression for expression in expression.expressions if expression.alias_or_name == existing]
+        existing_columns = [
+            expression
+            for expression in expression.expressions
+            if expression.alias_or_name == existing
+        ]
         if not existing_columns:
             raise ValueError("Tried to rename a column that doesn't exist")
         for existing_column in existing_columns:
@@ -674,15 +746,19 @@ class DataFrame:
     def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
         parameter_list = ensure_list(parameters)
         parameter_columns = (
-            self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id])
+            self._ensure_list_of_columns(parameter_list)
+            if parameters
+            else Column.ensure_cols([self.sequence_id])
         )
         return self._hint(name, parameter_columns)

     @operation(Operation.NO_OP)
-    def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame:
-        num_partitions = Column.ensure_cols(ensure_list(numPartitions))
+    def repartition(
+        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
+    ) -> DataFrame:
+        num_partition_cols = self._ensure_list_of_columns(numPartitions)
         columns = self._ensure_and_normalize_cols(cols)
-        args = num_partitions + columns
+        args = num_partition_cols + columns
         return self._hint("repartition", args)

     @operation(Operation.NO_OP)
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index bc002e5..dbfb06f 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -45,7 +45,11 @@ def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:

 def when(condition: Column, value: t.Any) -> Column:
     true_value = value if isinstance(value, Column) else lit(value)
-    return Column(glotexp.Case(ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]))
+    return Column(
+        glotexp.Case(
+            ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]
+        )
+    )

 def asc(col: ColumnOrName) -> Column:
@@ -407,7 +411,9 @@ def percentile_approx(
         return Column.invoke_expression_over_column(
             col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
         )
-    return Column.invoke_expression_over_column(col, glotexp.ApproxQuantile, quantile=lit(percentage))
+    return Column.invoke_expression_over_column(
+        col, glotexp.ApproxQuantile, quantile=lit(percentage)
+    )

 def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
@@ -471,7 +477,9 @@ def factorial(col: ColumnOrName) -> Column:
     return Column.invoke_anonymous_function(col, "FACTORIAL")

-def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None) -> Column:
+def lag(
+    col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None
+) -> Column:
     if default is not None:
         return Column.invoke_anonymous_function(col, "LAG", offset, default)
     if offset != 1:
@@ -479,7 +487,9 @@ def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[Colu
     return Column.invoke_anonymous_function(col, "LAG")

-def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None) -> Column:
+def lead(
+    col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None
+) -> Column:
     if default is not None:
         return Column.invoke_anonymous_function(col, "LEAD", offset, default)
     if offset != 1:
@@ -487,7 +497,9 @@ def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.A
     return Column.invoke_anonymous_function(col, "LEAD")

-def nth_value(col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None) -> Column:
+def nth_value(
+    col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None
+) -> Column:
     if ignoreNulls is not None:
         raise NotImplementedError("There is currently not support for `ignoreNulls` parameter")
     if offset != 1:
@@ -571,7 +583,9 @@ def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Colum
     return Column.invoke_anonymous_function(start, "ADD_MONTHS", months)

-def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None) -> Column:
+def months_between(
+    date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
+) -> Column:
     if roundOff is None:
         return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2)
     return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff)
@@ -611,9 +625,13 @@ def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
     return Column.invoke_expression_over_column(col, glotexp.UnixToStr)

-def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column:
+def unix_timestamp(
+    timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None
+) -> Column:
     if format is not None:
-        return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix, format=lit(format))
+        return Column.invoke_expression_over_column(
+            timestamp, glotexp.StrToUnix, format=lit(format)
+        )
     return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix)
@@ -642,7 +660,9 @@ def window(
             timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime)
         )
     if slideDuration is not None:
-        return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration))
+        return Column.invoke_anonymous_function(
+            timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration)
+        )
     if startTime is not None:
         return Column.invoke_anonymous_function(
             timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime)
@@ -731,7 +751,9 @@ def trim(col: ColumnOrName) -> Column:

 def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols))
+    return Column.invoke_expression_over_column(
+        None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols)
+    )

 def decode(col: ColumnOrName, charset: str) -> Column:
@@ -768,7 +790,9 @@ def overlay(

 def sentences(
-    string: ColumnOrName, language: t.Optional[ColumnOrName] = None, country: t.Optional[ColumnOrName] = None
+    string: ColumnOrName,
+    language: t.Optional[ColumnOrName] = None,
+    country: t.Optional[ColumnOrName] = None,
 ) -> Column:
     if language is not None and country is not None:
         return Column.invoke_anonymous_function(string, "SENTENCES", language, country)
@@ -794,7 +818,9 @@ def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
 def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
     substr_col = lit(substr)
     if pos is not None:
-        return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col, position=pos)
+        return Column.invoke_expression_over_column(
+            str, glotexp.StrPosition, substr=substr_col, position=pos
+        )
     return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col)
@@ -872,7 +898,10 @@ def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
 def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
     cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols  # type: ignore
     return Column.invoke_expression_over_column(
-        None, glotexp.VarMap, keys=array(*cols[::2]).expression, values=array(*cols[1::2]).expression
+        None,
+        glotexp.VarMap,
+        keys=array(*cols[::2]).expression,
+        values=array(*cols[1::2]).expression,
     )
@@ -882,29 +911,39 @@ def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:

 def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
     value_col = value if isinstance(value, Column) else lit(value)
-    return Column.invoke_expression_over_column(col, glotexp.ArrayContains, expression=value_col.expression)
+    return Column.invoke_expression_over_column(
+        col, glotexp.ArrayContains, expression=value_col.expression
+    )

 def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column:
     return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2))

-def slice(x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]) -> Column:
+def slice(
+    x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]
+) -> Column:
     start_col = start if isinstance(start, Column) else lit(start)
     length_col = length if isinstance(length, Column) else lit(length)
     return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col)

-def array_join(col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None) -> Column:
+def array_join(
+    col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None
+) -> Column:
     if null_replacement is not None:
-        return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement))
+        return Column.invoke_anonymous_function(
+            col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement)
+        )
     return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter))

 def concat(*cols: ColumnOrName) -> Column:
     if len(cols) == 1:
         return Column.invoke_anonymous_function(cols[0], "CONCAT")
-    return Column.invoke_anonymous_function(cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]])
+    return Column.invoke_anonymous_function(
+        cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]
+    )

 def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
@@ -1076,7 +1115,9 @@ def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column
     return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:])

-def sequence(start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None) -> Column:
+def sequence(
+    start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None
+) -> Column:
     if step is not None:
         return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step)
     return Column.invoke_anonymous_function(start, "SEQUENCE", stop)
@@ -1103,12 +1144,15 @@ def aggregate(
     merge_exp = _get_lambda_from_func(merge)
     if finish is not None:
         finish_exp = _get_lambda_from_func(finish)
-        return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp))
+        return Column.invoke_anonymous_function(
+            col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)
+        )
     return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))

 def transform(
-    col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]
+    col: ColumnOrName,
+    f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
 ) -> Column:
     f_expression = _get_lambda_from_func(f)
     return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
@@ -1124,12 +1168,17 @@ def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
     return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression))

-def filter(col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]) -> Column:
+def filter(
+    col: ColumnOrName,
+    f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
+) -> Column:
     f_expression = _get_lambda_from_func(f)
     return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression)

-def zip_with(left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column:
+def zip_with(
+    left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]
+) -> Column:
     f_expression = _get_lambda_from_func(f)
     return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression))
@@ -1163,7 +1212,10 @@ def _lambda_quoted(value: str) -> t.Optional[bool]:

 def _get_lambda_from_func(lambda_expression: t.Callable):
-    variables = [glotexp.to_identifier(x, quoted=_lambda_quoted(x)) for x in lambda_expression.__code__.co_varnames]
+    variables = [
+        glotexp.to_identifier(x, quoted=_lambda_quoted(x))
+        for x in lambda_expression.__code__.co_varnames
+    ]
     return glotexp.Lambda(
         this=lambda_expression(*[Column(x) for x in variables]).expression,
         expressions=variables,
diff --git a/sqlglot/dataframe/sql/group.py b/sqlglot/dataframe/sql/group.py
index 947aace..ba27c17 100644
--- a/sqlglot/dataframe/sql/group.py
+++ b/sqlglot/dataframe/sql/group.py
@@ -17,7 +17,9 @@ class GroupedData:
         self.last_op = last_op
         self.group_by_cols = group_by_cols

-    def _get_function_applied_columns(self, func_name: str, cols: t.Tuple[str, ...]) -> t.List[Column]:
+    def _get_function_applied_columns(
+        self, func_name: str, cols: t.Tuple[str, ...]
+    ) -> t.List[Column]:
         func_name = func_name.lower()
         return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
@@ -30,9 +32,9 @@ class GroupedData:
         )

         cols = self._df._ensure_and_normalize_cols(columns)
-        expression = self._df.expression.group_by(*[x.expression for x in self.group_by_cols]).select(
-            *[x.expression for x in self.group_by_cols + cols], append=False
-        )
+        expression = self._df.expression.group_by(
+            *[x.expression for x in self.group_by_cols]
+        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
         return self._df.copy(expression=expression)

     def count(self) -> DataFrame:
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
index 1513946..75feba7 100644
--- a/sqlglot/dataframe/sql/normalize.py
+++ b/sqlglot/dataframe/sql/normalize.py
@@ -23,7 +23,9 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
         replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)

-def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier):
+def replace_alias_name_with_cte_name(
+    spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
+):
     if id.alias_or_name in spark.name_to_sequence_id_mapping:
         for cte in reversed(expression_context.ctes):
             if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]:
@@ -40,8 +42,12 @@ def replace_branch_and_sequence_ids_with_cte_name(
     # id then it keeps that reference. This handles the weird edge case in spark that shouldn't
     # be common in practice
     if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
-        join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)]
-        ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases]
+        join_table_aliases = [
+            x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)
+        ]
+        ctes_in_join = [
+            cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases
+        ]
         if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
             assert len(ctes_in_join) == 2
             _set_alias_name(id, ctes_in_join[0].alias_or_name)
@@ -58,7 +64,6 @@ def _set_alias_name(id: exp.Identifier, name: str):

 def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
-    values = ensure_list(values)
     results = []
     for value in values:
         if isinstance(value, str):
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
index 4830035..febc664 100644
--- a/sqlglot/dataframe/sql/readwriter.py
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -19,12 +19,19 @@ class DataFrameReader:
         from sqlglot.dataframe.sql.dataframe import DataFrame

         sqlglot.schema.add_table(tableName)
-        return DataFrame(self.spark, exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)))
+        return DataFrame(
+            self.spark,
+            exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)),
+        )

 class DataFrameWriter:
     def __init__(
-        self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False
+        self,
+        df: DataFrame,
+        spark: t.Optional[SparkSession] = None,
+        mode: t.Optional[str] = None,
+        by_name: bool = False,
     ):
         self._df = df
         self._spark = spark or df.spark
@@ -33,7 +40,10 @@ class DataFrameWriter:

     def copy(self, **kwargs) -> DataFrameWriter:
         return DataFrameWriter(
-            **{k[1:] if k.startswith("_") else k: v for k, v in object_to_dict(self, **kwargs).items()}
+            **{
+                k[1:] if k.startswith("_") else k: v
+                for k, v in object_to_dict(self, **kwargs).items()
+            }
         )

     def sql(self, **kwargs) -> t.List[str]:
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index 1ea86d1..8cb16ef 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -67,13 +67,20 @@ class SparkSession:

         data_expressions = [
             exp.Tuple(
-                expressions=list(map(lambda x: F.lit(x).expression, row if not isinstance(row, dict) else row.values()))
+                expressions=list(
+                    map(
+                        lambda x: F.lit(x).expression,
+                        row if not isinstance(row, dict) else row.values(),
+                    )
+                )
             )
             for row in data
         ]

         sel_columns = [
-            F.col(name).cast(data_type).alias(name).expression if data_type is not None else F.col(name).expression
+            F.col(name).cast(data_type).alias(name).expression
+            if data_type is not None
+            else F.col(name).expression
             for name, data_type in column_mapping.items()
         ]
@@ -106,10 +113,12 @@ class SparkSession:
             select_expression.set("with", expression.args.get("with"))
             expression.set("with", None)
             del expression.args["expression"]
-            df = DataFrame(self, select_expression, output_expression_container=expression)
+            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
             df = df._convert_leaf_to_cte()
         else:
-            raise ValueError("Unknown expression type provided in the SQL. Please create an issue with the SQL.")
+            raise ValueError(
+                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
+            )
         return df

     @property
diff --git a/sqlglot/dataframe/sql/types.py b/sqlglot/dataframe/sql/types.py
index dc5c05a..a63e505 100644
--- a/sqlglot/dataframe/sql/types.py
+++ b/sqlglot/dataframe/sql/types.py
@@ -158,7 +158,11 @@ class MapType(DataType):

 class StructField(DataType):
     def __init__(
-        self, name: str, dataType: DataType, nullable: bool = True, metadata: t.Optional[t.Dict[str, t.Any]] = None
+        self,
+        name: str,
+        dataType: DataType,
+        nullable: bool = True,
+        metadata: t.Optional[t.Dict[str, t.Any]] = None,
     ):
         self.name = name
         self.dataType = dataType
diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py
index 842f366..c54c07e 100644
--- a/sqlglot/dataframe/sql/window.py
+++ b/sqlglot/dataframe/sql/window.py
@@ -74,8 +74,13 @@ class WindowSpec:
             window_spec.expression.args["order"].set("expressions", order_by)
         return window_spec

-    def _calc_start_end(self, start: int, end: int) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
-        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {"start_side": None, "end_side": None}
+    def _calc_start_end(
+        self, start: int, end: int
+    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
+        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
+            "start_side": None,
+            "end_side": None,
+        }
         if start == Window.currentRow:
             kwargs["start"] = "CURRENT ROW"
         else:
@@ -83,7 +88,9 @@ class WindowSpec:
                 **kwargs,
                 **{
                     "start_side": "PRECEDING",
-                    "start": "UNBOUNDED" if start <= Window.unboundedPreceding else F.lit(start).expression,
+                    "start": "UNBOUNDED"
+                    if start <= Window.unboundedPreceding
+                    else F.lit(start).expression,
                 },
             }
         if end == Window.currentRow:
@@ -93,7 +100,9 @@ class WindowSpec:
                 **kwargs,
                 **{
                     "end_side": "FOLLOWING",
-                    "end": "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression,
+                    "end": "UNBOUNDED"
+                    if end >= Window.unboundedFollowing
+                    else F.lit(end).expression,
                 },
             }
         return kwargs
@@ -103,7 +112,10 @@ class WindowSpec:
         spec = self._calc_start_end(start, end)
         spec["kind"] = "ROWS"
         window_spec.expression.set(
-            "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
+            "spec",
+            exp.WindowSpec(
+                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
+            ),
         )
         return window_spec
@@ -112,6 +124,9 @@ class WindowSpec:
         spec = self._calc_start_end(start, end)
         spec["kind"] = "RANGE"
         window_spec.expression.set(
-            "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
+            "spec",
+            exp.WindowSpec(
+                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
+            ),
         )
         return window_spec
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 62d042e..5bbff9d 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -1,21 +1,21 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
     inline_array_sql,
     no_ilike_sql,
     rename_func,
 )
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType

 def _date_add(expression_class):
     def func(args):
-        interval = list_get(args, 1)
+        interval = seq_get(args, 1)
         return expression_class(
-            this=list_get(args, 0),
+            this=seq_get(args, 0),
             expression=interval.this,
             unit=interval.args.get("unit"),
         )
@@ -23,6 +23,13 @@ def _date_add(expression_class):
     return func

+def _date_trunc(args):
+    unit = seq_get(args, 1)
+    if isinstance(unit, exp.Column):
+        unit = exp.Var(this=unit.name)
+    return exp.DateTrunc(this=seq_get(args, 0), expression=unit)
+
+
 def _date_add_sql(data_type, kind):
     def func(self, expression):
         this = self.sql(expression, "this")
@@ -40,7 +47,8 @@ def _derived_table_values_to_unnest(self, expression):
     structs = []
     for row in rows:
         aliases = [
-            exp.alias_(value, column_name) for value, column_name in zip(row, expression.args["alias"].args["columns"])
+            exp.alias_(value, column_name)
+            for value, column_name in zip(row, expression.args["alias"].args["columns"])
         ]
         structs.append(exp.Struct(expressions=aliases))
     unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])
@@ -89,18 +97,19 @@ class BigQuery(Dialect):
         "%j": "%-j",
     }

-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         QUOTES = [
             (prefix + quote, quote) if prefix else quote
             for quote in ["'", '"', '"""', "'''"]
             for prefix in ["", "r", "R"]
         ]
+        COMMENTS = ["--", "#", ("/*", "*/")]
         IDENTIFIERS = ["`"]
-        ESCAPE = "\\"
+        ESCAPES = ["\\"]
         HEX_STRINGS = [("0x", ""), ("0X", "")]

         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
             "CURRENT_TIME": TokenType.CURRENT_TIME,
             "GEOGRAPHY": TokenType.GEOGRAPHY,
@@ -111,35 +120,40 @@ class BigQuery(Dialect):
             "WINDOW": TokenType.WINDOW,
             "NOT DETERMINISTIC": TokenType.VOLATILE,
         }
+        KEYWORDS.pop("DIV")

-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
+            "DATE_TRUNC": _date_trunc,
             "DATE_ADD": _date_add(exp.DateAdd),
             "DATETIME_ADD": _date_add(exp.DatetimeAdd),
+            "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
             "TIME_ADD": _date_add(exp.TimeAdd),
             "TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
             "DATE_SUB": _date_add(exp.DateSub),
             "DATETIME_SUB": _date_add(exp.DatetimeSub),
             "TIME_SUB": _date_add(exp.TimeSub),
             "TIMESTAMP_SUB": _date_add(exp.TimestampSub),
-            "PARSE_TIMESTAMP": lambda args: exp.StrToTime(this=list_get(args, 1), format=list_get(args, 0)),
+            "PARSE_TIMESTAMP": lambda args: exp.StrToTime(
+                this=seq_get(args, 1), format=seq_get(args, 0)
+            ),
         }

         NO_PAREN_FUNCTIONS = {
-            **Parser.NO_PAREN_FUNCTIONS,
+            **parser.Parser.NO_PAREN_FUNCTIONS,
             TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
             TokenType.CURRENT_TIME: exp.CurrentTime,
         }

         NESTED_TYPE_TOKENS = {
-            *Parser.NESTED_TYPE_TOKENS,
+            *parser.Parser.NESTED_TYPE_TOKENS,
             TokenType.TABLE,
         }

-    class Generator(Generator):
+    class Generator(generator.Generator):
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
             exp.Array: inline_array_sql,
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.DateAdd: _date_add_sql("DATE", "ADD"),
@@ -148,6 +162,7 @@ class BigQuery(Dialect):
             exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
             exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
             exp.ILike: no_ilike_sql,
+            exp.IntDiv: rename_func("DIV"),
             exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
             exp.TimeAdd: _date_add_sql("TIME", "ADD"),
             exp.TimeSub: _date_add_sql("TIME", "SUB"),
@@ -157,11 +172,13 @@ class BigQuery(Dialect):
             exp.Values: _derived_table_values_to_unnest,
             exp.ReturnsProperty: _returnsproperty_sql,
             exp.Create: _create_sql,
-            exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC",
+            exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
+            if e.name == "IMMUTABLE"
+            else "NOT DETERMINISTIC",
         }

         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.TINYINT: "INT64",
             exp.DataType.Type.SMALLINT: "INT64",
             exp.DataType.Type.INT: "INT64",
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index f446e6d..332b4c1 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -1,8 +1,9 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser, parse_var_map
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.parser import parse_var_map
+from sqlglot.tokens import TokenType

 def _lower_func(sql):
@@ -14,11 +15,12 @@ class ClickHouse(Dialect):
     normalize_functions = None
     null_ordering = "nulls_are_last"

-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
+        COMMENTS = ["--", "#", "#!", ("/*", "*/")]
         IDENTIFIERS = ['"', "`"]

         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             "FINAL": TokenType.FINAL,
             "DATETIME64": TokenType.DATETIME,
             "INT8": TokenType.TINYINT,
@@ -30,9 +32,9 @@ class ClickHouse(Dialect):
             "TUPLE": TokenType.STRUCT,
         }

-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "MAP": parse_var_map,
         }
@@ -44,11 +46,11 @@ class ClickHouse(Dialect):

         return this

-    class Generator(Generator):
+    class Generator(generator.Generator):
         STRUCT_DELIMITER = ("(", ")")

         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.NULLABLE: "Nullable",
             exp.DataType.Type.DATETIME: "DateTime64",
             exp.DataType.Type.MAP: "Map",
@@ -63,7 +65,7 @@ class ClickHouse(Dialect):
         }

         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
             exp.Array: inline_array_sql,
             exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
             exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 9dc3c38..2498c62 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from sqlglot import exp
 from sqlglot.dialects.dialect import parse_date_delta
 from sqlglot.dialects.spark import Spark
@@ -15,7 +17,7 @@ class Databricks(Spark):

     class Generator(Spark.Generator):
         TRANSFORMS = {
-            **Spark.Generator.TRANSFORMS,
+            **Spark.Generator.TRANSFORMS,  # type: ignore
             exp.DateAdd: generate_date_delta_with_unit_sql,
             exp.DateDiff: generate_date_delta_with_unit_sql,
         }
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 33985a7..3af08bb 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -1,8 +1,11 @@
+from __future__ import annotations
+
+import typing as t
 from enum import Enum

 from sqlglot import exp
 from sqlglot.generator import Generator
-from sqlglot.helper import flatten, list_get
+from sqlglot.helper import flatten, seq_get
 from sqlglot.parser import Parser
 from sqlglot.time import format_time
 from sqlglot.tokens import Tokenizer
@@ -32,7 +35,7 @@ class Dialects(str, Enum):

 class _Dialect(type):
-    classes = {}
+    classes: t.Dict[str, Dialect] = {}

     @classmethod
     def __getitem__(cls, key):
@@ -56,19 +59,30 @@ class _Dialect(type):
         klass.generator_class = getattr(klass, "Generator", Generator)

         klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
-        klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
-
-        if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS:
+        klass.identifier_start, klass.identifier_end = list(
+            klass.tokenizer_class._IDENTIFIERS.items()
+        )[0]
+
+        if (
+            klass.tokenizer_class._BIT_STRINGS
+            and exp.BitString not in klass.generator_class.TRANSFORMS
+        ):
             bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
             klass.generator_class.TRANSFORMS[
                 exp.BitString
             ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
-        if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS:
+        if (
+            klass.tokenizer_class._HEX_STRINGS
+            and exp.HexString not in klass.generator_class.TRANSFORMS
+        ):
             hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
             klass.generator_class.TRANSFORMS[
                 exp.HexString
             ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
-        if klass.tokenizer_class._BYTE_STRINGS and exp.ByteString not in klass.generator_class.TRANSFORMS:
+        if (
+            klass.tokenizer_class._BYTE_STRINGS
+            and exp.ByteString not in klass.generator_class.TRANSFORMS
+        ):
             be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
             klass.generator_class.TRANSFORMS[
                 exp.ByteString
@@ -81,13 +95,13 @@ class Dialect(metaclass=_Dialect):
     index_offset = 0
     unnest_column_only = False
     alias_post_tablesample = False
-    normalize_functions = "upper"
+    normalize_functions: t.Optional[str] = "upper"
     null_ordering = "nulls_are_small"

     date_format = "'%Y-%m-%d'"
     dateint_format = "'%Y%m%d'"
     time_format = "'%Y-%m-%d %H:%M:%S'"
-    time_mapping = {}
+    time_mapping: t.Dict[str, str] = {}

     # autofilled
     quote_start = None
@@ -167,7 +181,7 @@ class Dialect(metaclass=_Dialect):
             "quote_end": self.quote_end,
             "identifier_start": self.identifier_start,
             "identifier_end": self.identifier_end,
-            "escape": self.tokenizer_class.ESCAPE,
+            "escape": self.tokenizer_class.ESCAPES[0],
             "index_offset": self.index_offset,
             "time_mapping": self.inverse_time_mapping,
             "time_trie": self.inverse_time_trie,
@@ -195,7 +209,9 @@ def approx_count_distinct_sql(self, expression):

 def if_sql(self, expression):
-    expressions = self.format_args(expression.this, expression.args.get("true"), expression.args.get("false"))
+    expressions = self.format_args(
+        expression.this, expression.args.get("true"), expression.args.get("false")
+    )
     return f"IF({expressions})"
@@ -298,9 +314,9 @@ def format_time_lambda(exp_class, dialect, default=None):

     def _format_time(args):
         return exp_class(
-            this=list_get(args, 0),
+            this=seq_get(args, 0),
             format=Dialect[dialect].format_time(
-                list_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
+                seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
             ),
         )
@@ -328,7 +344,9 @@ def create_with_partitions_sql(self, expression):
             "expressions",
             [e for e in schema.expressions if e not in partitions],
         )
-        prop.replace(exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions)))
+        prop.replace(
+            exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
+        )
         expression.set("this", schema)

     return self.create_sql(expression)
@@ -337,9 +355,9 @@ def create_with_partitions_sql(self, expression):
 def parse_date_delta(exp_class, unit_mapping=None):
     def inner_func(args):
         unit_based = len(args) == 3
-        this = list_get(args, 2) if unit_based else list_get(args, 0)
-        expression = list_get(args, 1) if unit_based else list_get(args, 1)
-        unit = list_get(args, 0) if unit_based else exp.Literal.string("DAY")
+        this = seq_get(args, 2) if unit_based else seq_get(args, 0)
+        expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
+        unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
         unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
         return exp_class(this=this, expression=expression, unit=unit)
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index f3ff6d3..781edff 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
     approx_count_distinct_sql,
@@ -12,10 +14,8 @@ from sqlglot.dialects.dialect import (
     rename_func,
     str_position_sql,
 )
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType

 def _unix_to_time(self, expression):
@@ -61,11 +61,14 @@ def _sort_array_sql(self, expression):

 def _sort_array_reverse(args):
-    return exp.SortArray(this=list_get(args, 0), asc=exp.FALSE)
+    return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE)

 def _struct_pack_sql(self, expression):
-    args = [self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) for e in expression.expressions]
+    args = [
+        self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e)
+        for e in expression.expressions
+    ]
     return f"STRUCT_PACK({', '.join(args)})"
@@ -76,15 +79,15 @@ def _datatype_sql(self, expression):

 class DuckDB(Dialect):
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             ":=": TokenType.EQ,
         }

-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
             "ARRAY_LENGTH": exp.ArraySize.from_arg_list,
             "ARRAY_SORT": exp.SortArray.from_arg_list,
@@ -92,7 +95,7 @@ class DuckDB(Dialect):
             "EPOCH": exp.TimeToUnix.from_arg_list,
             "EPOCH_MS": lambda args: exp.UnixToTime(
                 this=exp.Div(
-                    this=list_get(args, 0),
+                    this=seq_get(args, 0),
                     expression=exp.Literal.number(1000),
                 )
             ),
@@ -112,11 +115,11 @@ class DuckDB(Dialect):
             "UNNEST": exp.Explode.from_arg_list,
         }

-    class Generator(Generator):
+    class Generator(generator.Generator):
         STRUCT_DELIMITER = ("(", ")")

         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.Array: rename_func("LIST_VALUE"),
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
@@ -160,7 +163,7 @@ class DuckDB(Dialect):
         }

         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.VARCHAR: "TEXT",
             exp.DataType.Type.NVARCHAR: "TEXT",
         }
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 03049ff..ed7357c 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -1,4 +1,6 @@
-from sqlglot import exp, transforms
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     approx_count_distinct_sql,
@@ -13,10 +15,8 @@ from sqlglot.dialects.dialect import (
     struct_extract_sql,
     var_map_sql,
 )
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser, parse_var_map
-from sqlglot.tokens import Tokenizer
+from sqlglot.helper import seq_get
+from sqlglot.parser import parse_var_map

 # (FuncType, Multiplier)
 DATE_DELTA_INTERVAL = {
@@ -34,7 +34,9 @@ def _add_date_sql(self, expression):
     unit = expression.text("unit").upper()
     func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
     modified_increment = (
-        int(expression.text("expression")) * multiplier if expression.expression.is_number else expression.expression
+        int(expression.text("expression")) * multiplier
+        if expression.expression.is_number
+        else expression.expression
     )
     modified_increment = exp.Literal.number(modified_increment)
     return f"{func}({self.format_args(expression.this, modified_increment.this)})"
@@ -165,10 +167,10 @@ class Hive(Dialect):
     dateint_format = "'yyyyMMdd'"
     time_format = "'yyyy-MM-dd HH:mm:ss'"

-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
        QUOTES = ["'", '"']
         IDENTIFIERS = ["`"]
-        ESCAPE = "\\"
+        ESCAPES = ["\\"]
         ENCODE = "utf-8"

         NUMERIC_LITERALS = {
@@ -180,40 +182,44 @@ class Hive(Dialect):
             "BD": "DECIMAL",
         }

-    class Parser(Parser):
+    class Parser(parser.Parser):
         STRICT_CAST = False

         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
             "COLLECT_LIST": exp.ArrayAgg.from_arg_list,
             "DATE_ADD": lambda args: exp.TsOrDsAdd(
-                this=list_get(args, 0),
-                expression=list_get(args, 1),
+                this=seq_get(args, 0),
+                expression=seq_get(args, 1),
                 unit=exp.Literal.string("DAY"),
             ),
             "DATEDIFF": lambda args: exp.DateDiff(
-                this=exp.TsOrDsToDate(this=list_get(args, 0)),
-                expression=exp.TsOrDsToDate(this=list_get(args, 1)),
+                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+                expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
             ),
             "DATE_SUB": lambda args: exp.TsOrDsAdd(
-                this=list_get(args, 0),
+                this=seq_get(args, 0),
                 expression=exp.Mul(
-                    this=list_get(args, 1),
+                    this=seq_get(args, 1),
                     expression=exp.Literal.number(-1),
                 ),
                 unit=exp.Literal.string("DAY"),
             ),
             "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"),
-            "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=list_get(args, 0))),
+            "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
             "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
             "LOCATE": lambda args: exp.StrPosition(
-                this=list_get(args, 1),
-                substr=list_get(args, 0),
-                position=list_get(args, 2),
+                this=seq_get(args, 1),
+                substr=seq_get(args, 0),
+                position=seq_get(args, 2),
+            ),
+            "LOG": (
+                lambda args: exp.Log.from_arg_list(args)
+                if len(args) > 1
+                else exp.Ln.from_arg_list(args)
             ),
-            "LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)),
             "MAP": parse_var_map,
             "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
             "PERCENTILE": exp.Quantile.from_arg_list,
@@ -226,15 +232,16 @@ class Hive(Dialect):
             "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
         }

-    class 
Generator(generator.Generator): TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.VARBINARY: "BINARY", } TRANSFORMS = { - **Generator.TRANSFORMS, - **transforms.UNALIAS_GROUP, + **generator.Generator.TRANSFORMS, + **transforms.UNALIAS_GROUP, # type: ignore exp.AnonymousProperty: _property_sql, exp.ApproxDistinct: approx_count_distinct_sql, exp.ArrayAgg: rename_func("COLLECT_LIST"), diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 524390f..e742640 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -1,4 +1,8 @@ -from sqlglot import exp +from __future__ import annotations + +import typing as t + +from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, no_ilike_sql, @@ -6,42 +10,47 @@ from sqlglot.dialects.dialect import ( no_tablesample_sql, no_trycast_sql, ) -from sqlglot.generator import Generator -from sqlglot.helper import list_get -from sqlglot.parser import Parser -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.helper import seq_get +from sqlglot.tokens import TokenType + + +def _show_parser(*args, **kwargs): + def _parse(self): + return self._parse_show_mysql(*args, **kwargs) + + return _parse def _date_trunc_sql(self, expression): - unit = expression.text("unit").lower() + unit = expression.name.lower() - this = self.sql(expression.this) + expr = self.sql(expression.expression) if unit == "day": - return f"DATE({this})" + return f"DATE({expr})" if unit == "week": - concat = f"CONCAT(YEAR({this}), ' ', WEEK({this}, 1), ' 1')" + concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')" date_format = "%Y %u %w" elif unit == "month": - concat = f"CONCAT(YEAR({this}), ' ', MONTH({this}), ' 1')" + concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')" date_format = "%Y %c %e" elif unit == "quarter": - concat = f"CONCAT(YEAR({this}), ' ', QUARTER({this}) * 3 - 2, ' 1')" + concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')" date_format = "%Y %c %e" elif unit == "year": - concat = f"CONCAT(YEAR({this}), ' 1 1')" + concat = f"CONCAT(YEAR({expr}), ' 1 1')" date_format = "%Y %c %e" else: self.unsupported("Unexpected interval unit: {unit}") - return f"DATE({this})" + return f"DATE({expr})" return f"STR_TO_DATE({concat}, '{date_format}')" def _str_to_date(args): - date_format = MySQL.format_time(list_get(args, 1)) - return exp.StrToDate(this=list_get(args, 0), format=date_format) + date_format = MySQL.format_time(seq_get(args, 1)) + return exp.StrToDate(this=seq_get(args, 0), format=date_format) def _str_to_date_sql(self, expression): @@ -66,9 +75,9 @@ def _trim_sql(self, expression): def _date_add(expression_class): def func(args): - interval = list_get(args, 1) + interval = seq_get(args, 1) return expression_class( - this=list_get(args, 0), + this=seq_get(args, 0), expression=interval.this, unit=exp.Literal.string(interval.text("unit").lower()), ) @@ -101,15 +110,16 @@ class MySQL(Dialect): "%l": "%-I", } - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): QUOTES = ["'", '"'] COMMENTS = ["--", "#", ("/*", "*/")] IDENTIFIERS = ["`"] + ESCAPES = ["'", "\\"] BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")] HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")] KEYWORDS = { - **Tokenizer.KEYWORDS, + **tokens.Tokenizer.KEYWORDS, "SEPARATOR": TokenType.SEPARATOR, "_ARMSCII8": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER, @@ -156,20 +166,23 @@ class MySQL(Dialect): 
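# Annotation (illustrative sketch, not part of the upstream hunks, which
# resume below): this MySQL change registers "@@" as a SESSION_PARAMETER
# token and removes SET and SHOW from the tokenizer's COMMANDS, so the new
# STATEMENT_PARSERS entries can build structured exp.Show / exp.Set trees
# instead of opaque Command nodes. Assuming the patched sqlglot is installed,
# usage would look roughly like:
#
#     import sqlglot
#     from sqlglot import exp
#
#     show = sqlglot.parse_one(
#         "SHOW FULL TABLES FROM mydb LIKE 'user%'", read="mysql"
#     )
#     assert isinstance(show, exp.Show)
#     show.sql(dialect="mysql")  # round-trips the SHOW statement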
"_UTF32": TokenType.INTRODUCER, "_UTF8MB3": TokenType.INTRODUCER, "_UTF8MB4": TokenType.INTRODUCER, + "@@": TokenType.SESSION_PARAMETER, } - class Parser(Parser): + COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW} + + class Parser(parser.Parser): STRICT_CAST = False FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "DATE_ADD": _date_add(exp.DateAdd), "DATE_SUB": _date_add(exp.DateSub), "STR_TO_DATE": _str_to_date, } FUNCTION_PARSERS = { - **Parser.FUNCTION_PARSERS, + **parser.Parser.FUNCTION_PARSERS, "GROUP_CONCAT": lambda self: self.expression( exp.GroupConcat, this=self._parse_lambda(), @@ -178,15 +191,212 @@ class MySQL(Dialect): } PROPERTY_PARSERS = { - **Parser.PROPERTY_PARSERS, + **parser.Parser.PROPERTY_PARSERS, TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty), } - class Generator(Generator): + STATEMENT_PARSERS = { + **parser.Parser.STATEMENT_PARSERS, + TokenType.SHOW: lambda self: self._parse_show(), + TokenType.SET: lambda self: self._parse_set(), + } + + SHOW_PARSERS = { + "BINARY LOGS": _show_parser("BINARY LOGS"), + "MASTER LOGS": _show_parser("BINARY LOGS"), + "BINLOG EVENTS": _show_parser("BINLOG EVENTS"), + "CHARACTER SET": _show_parser("CHARACTER SET"), + "CHARSET": _show_parser("CHARACTER SET"), + "COLLATION": _show_parser("COLLATION"), + "FULL COLUMNS": _show_parser("COLUMNS", target="FROM", full=True), + "COLUMNS": _show_parser("COLUMNS", target="FROM"), + "CREATE DATABASE": _show_parser("CREATE DATABASE", target=True), + "CREATE EVENT": _show_parser("CREATE EVENT", target=True), + "CREATE FUNCTION": _show_parser("CREATE FUNCTION", target=True), + "CREATE PROCEDURE": _show_parser("CREATE PROCEDURE", target=True), + "CREATE TABLE": _show_parser("CREATE TABLE", target=True), + "CREATE TRIGGER": _show_parser("CREATE TRIGGER", target=True), + "CREATE VIEW": _show_parser("CREATE VIEW", target=True), + "DATABASES": _show_parser("DATABASES"), + "ENGINE": _show_parser("ENGINE", target=True), + "STORAGE ENGINES": _show_parser("ENGINES"), + "ENGINES": _show_parser("ENGINES"), + "ERRORS": _show_parser("ERRORS"), + "EVENTS": _show_parser("EVENTS"), + "FUNCTION CODE": _show_parser("FUNCTION CODE", target=True), + "FUNCTION STATUS": _show_parser("FUNCTION STATUS"), + "GRANTS": _show_parser("GRANTS", target="FOR"), + "INDEX": _show_parser("INDEX", target="FROM"), + "MASTER STATUS": _show_parser("MASTER STATUS"), + "OPEN TABLES": _show_parser("OPEN TABLES"), + "PLUGINS": _show_parser("PLUGINS"), + "PROCEDURE CODE": _show_parser("PROCEDURE CODE", target=True), + "PROCEDURE STATUS": _show_parser("PROCEDURE STATUS"), + "PRIVILEGES": _show_parser("PRIVILEGES"), + "FULL PROCESSLIST": _show_parser("PROCESSLIST", full=True), + "PROCESSLIST": _show_parser("PROCESSLIST"), + "PROFILE": _show_parser("PROFILE"), + "PROFILES": _show_parser("PROFILES"), + "RELAYLOG EVENTS": _show_parser("RELAYLOG EVENTS"), + "REPLICAS": _show_parser("REPLICAS"), + "SLAVE HOSTS": _show_parser("REPLICAS"), + "REPLICA STATUS": _show_parser("REPLICA STATUS"), + "SLAVE STATUS": _show_parser("REPLICA STATUS"), + "GLOBAL STATUS": _show_parser("STATUS", global_=True), + "SESSION STATUS": _show_parser("STATUS"), + "STATUS": _show_parser("STATUS"), + "TABLE STATUS": _show_parser("TABLE STATUS"), + "FULL TABLES": _show_parser("TABLES", full=True), + "TABLES": _show_parser("TABLES"), + "TRIGGERS": _show_parser("TRIGGERS"), + "GLOBAL VARIABLES": _show_parser("VARIABLES", global_=True), + "SESSION VARIABLES": _show_parser("VARIABLES"), + "VARIABLES": 
_show_parser("VARIABLES"), + "WARNINGS": _show_parser("WARNINGS"), + } + + SET_PARSERS = { + "GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"), + "PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"), + "PERSIST_ONLY": lambda self: self._parse_set_item_assignment("PERSIST_ONLY"), + "SESSION": lambda self: self._parse_set_item_assignment("SESSION"), + "LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"), + "CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"), + "CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"), + "NAMES": lambda self: self._parse_set_item_names(), + } + + PROFILE_TYPES = { + "ALL", + "BLOCK IO", + "CONTEXT SWITCHES", + "CPU", + "IPC", + "MEMORY", + "PAGE FAULTS", + "SOURCE", + "SWAPS", + } + + def _parse_show_mysql(self, this, target=False, full=None, global_=None): + if target: + if isinstance(target, str): + self._match_text(target) + target_id = self._parse_id_var() + else: + target_id = None + + log = self._parse_string() if self._match_text("IN") else None + + if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}: + position = self._parse_number() if self._match_text("FROM") else None + db = None + else: + position = None + db = self._parse_id_var() if self._match_text("FROM") else None + + channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None + + like = self._parse_string() if self._match_text("LIKE") else None + where = self._parse_where() + + if this == "PROFILE": + types = self._parse_csv(self._parse_show_profile_type) + query = self._parse_number() if self._match_text("FOR", "QUERY") else None + offset = self._parse_number() if self._match_text("OFFSET") else None + limit = self._parse_number() if self._match_text("LIMIT") else None + else: + types, query = None, None + offset, limit = self._parse_oldstyle_limit() + + mutex = True if self._match_text("MUTEX") else None + mutex = False if self._match_text("STATUS") else mutex + + return self.expression( + exp.Show, + this=this, + target=target_id, + full=full, + log=log, + position=position, + db=db, + channel=channel, + like=like, + where=where, + types=types, + query=query, + offset=offset, + limit=limit, + mutex=mutex, + **{"global": global_}, + ) + + def _parse_show_profile_type(self): + for type_ in self.PROFILE_TYPES: + if self._match_text(*type_.split(" ")): + return exp.Var(this=type_) + return None + + def _parse_oldstyle_limit(self): + limit = None + offset = None + if self._match_text("LIMIT"): + parts = self._parse_csv(self._parse_number) + if len(parts) == 1: + limit = parts[0] + elif len(parts) == 2: + limit = parts[1] + offset = parts[0] + return offset, limit + + def _default_parse_set_item(self): + return self._parse_set_item_assignment(kind=None) + + def _parse_set_item_assignment(self, kind): + left = self._parse_primary() or self._parse_id_var() + if not self._match(TokenType.EQ): + self.raise_error("Expected =") + right = self._parse_statement() or self._parse_id_var() + + this = self.expression( + exp.EQ, + this=left, + expression=right, + ) + + return self.expression( + exp.SetItem, + this=this, + kind=kind, + ) + + def _parse_set_item_charset(self, kind): + this = self._parse_string() or self._parse_id_var() + + return self.expression( + exp.SetItem, + this=this, + kind=kind, + ) + + def _parse_set_item_names(self): + charset = self._parse_string() or self._parse_id_var() + if self._match_text("COLLATE"): + collate = self._parse_string() or self._parse_id_var() + else: + collate = None + 
return self.expression( + exp.SetItem, + this=charset, + collate=collate, + kind="NAMES", + ) + + class Generator(generator.Generator): NULL_ORDERING_SUPPORTED = False TRANSFORMS = { - **Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, exp.CurrentDate: no_paren_current_date_sql, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.ILike: no_ilike_sql, @@ -199,6 +409,8 @@ class MySQL(Dialect): exp.StrToDate: _str_to_date_sql, exp.StrToTime: _str_to_date_sql, exp.Trim: _trim_sql, + exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), + exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), } ROOT_PROPERTIES = { @@ -209,4 +421,69 @@ class MySQL(Dialect): exp.SchemaCommentProperty, } - WITH_PROPERTIES = {} + WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set() + + def show_sql(self, expression): + this = f" {expression.name}" + full = " FULL" if expression.args.get("full") else "" + global_ = " GLOBAL" if expression.args.get("global") else "" + + target = self.sql(expression, "target") + target = f" {target}" if target else "" + if expression.name in {"COLUMNS", "INDEX"}: + target = f" FROM{target}" + elif expression.name == "GRANTS": + target = f" FOR{target}" + + db = self._prefixed_sql("FROM", expression, "db") + + like = self._prefixed_sql("LIKE", expression, "like") + where = self.sql(expression, "where") + + types = self.expressions(expression, key="types") + types = f" {types}" if types else types + query = self._prefixed_sql("FOR QUERY", expression, "query") + + if expression.name == "PROFILE": + offset = self._prefixed_sql("OFFSET", expression, "offset") + limit = self._prefixed_sql("LIMIT", expression, "limit") + else: + offset = "" + limit = self._oldstyle_limit_sql(expression) + + log = self._prefixed_sql("IN", expression, "log") + position = self._prefixed_sql("FROM", expression, "position") + + channel = self._prefixed_sql("FOR CHANNEL", expression, "channel") + + if expression.name == "ENGINE": + mutex_or_status = " MUTEX" if expression.args.get("mutex") else " STATUS" + else: + mutex_or_status = "" + + return f"SHOW{full}{global_}{this}{target}{types}{db}{query}{log}{position}{channel}{mutex_or_status}{like}{where}{offset}{limit}" + + def _prefixed_sql(self, prefix, expression, arg): + sql = self.sql(expression, arg) + if not sql: + return "" + return f" {prefix} {sql}" + + def _oldstyle_limit_sql(self, expression): + limit = self.sql(expression, "limit") + offset = self.sql(expression, "offset") + if limit: + limit_offset = f"{offset}, {limit}" if offset else limit + return f" LIMIT {limit_offset}" + return "" + + def setitem_sql(self, expression): + kind = self.sql(expression, "kind") + kind = f"{kind} " if kind else "" + this = self.sql(expression, "this") + collate = self.sql(expression, "collate") + collate = f" COLLATE {collate}" if collate else "" + return f"{kind}{this}{collate}" + + def set_sql(self, expression): + return f"SET {self.expressions(expression)}" diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 144dba5..3bc1109 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -1,8 +1,9 @@ -from sqlglot import exp, transforms +from __future__ import annotations + +from sqlglot import exp, generator, tokens, transforms from sqlglot.dialects.dialect import Dialect, no_ilike_sql -from sqlglot.generator import Generator from sqlglot.helper import csv -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.tokens import TokenType def _limit_sql(self, expression): @@ -36,9 +37,9 @@ class 
Oracle(Dialect): "YYYY": "%Y", # 2015 } - class Generator(Generator): + class Generator(generator.Generator): TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TINYINT: "NUMBER", exp.DataType.Type.SMALLINT: "NUMBER", exp.DataType.Type.INT: "NUMBER", @@ -49,11 +50,12 @@ class Oracle(Dialect): exp.DataType.Type.NVARCHAR: "NVARCHAR2", exp.DataType.Type.TEXT: "CLOB", exp.DataType.Type.BINARY: "BLOB", + exp.DataType.Type.VARBINARY: "BLOB", } TRANSFORMS = { - **Generator.TRANSFORMS, - **transforms.UNALIAS_GROUP, + **generator.Generator.TRANSFORMS, + **transforms.UNALIAS_GROUP, # type: ignore exp.ILike: no_ilike_sql, exp.Limit: _limit_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", @@ -86,9 +88,9 @@ class Oracle(Dialect): def table_sql(self, expression): return super().table_sql(expression, sep=" ") - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): KEYWORDS = { - **Tokenizer.KEYWORDS, + **tokens.Tokenizer.KEYWORDS, "TOP": TokenType.TOP, "VARCHAR2": TokenType.VARCHAR, "NVARCHAR2": TokenType.NVARCHAR, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 459e926..553a73b 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -1,4 +1,6 @@ -from sqlglot import exp +from __future__ import annotations + +from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, arrow_json_extract_scalar_sql, @@ -9,9 +11,7 @@ from sqlglot.dialects.dialect import ( no_trycast_sql, str_position_sql, ) -from sqlglot.generator import Generator -from sqlglot.parser import Parser -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.tokens import TokenType from sqlglot.transforms import delegate, preprocess @@ -160,12 +160,12 @@ class Postgres(Dialect): "YYYY": "%Y", # 2015 } - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): BIT_STRINGS = [("b'", "'"), ("B'", "'")] HEX_STRINGS = [("x'", "'"), ("X'", "'")] BYTE_STRINGS = [("e'", "'"), ("E'", "'")] KEYWORDS = { - **Tokenizer.KEYWORDS, + **tokens.Tokenizer.KEYWORDS, "ALWAYS": TokenType.ALWAYS, "BY DEFAULT": TokenType.BY_DEFAULT, "COMMENT ON": TokenType.COMMENT_ON, @@ -179,31 +179,32 @@ class Postgres(Dialect): } QUOTES = ["'", "$$"] SINGLE_TOKENS = { - **Tokenizer.SINGLE_TOKENS, + **tokens.Tokenizer.SINGLE_TOKENS, "$": TokenType.PARAMETER, } - class Parser(Parser): + class Parser(parser.Parser): STRICT_CAST = False FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "TO_TIMESTAMP": _to_timestamp, "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), } - class Generator(Generator): + class Generator(generator.Generator): TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TINYINT: "SMALLINT", exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", exp.DataType.Type.BINARY: "BYTEA", + exp.DataType.Type.VARBINARY: "BYTEA", exp.DataType.Type.DATETIME: "TIMESTAMP", } TRANSFORMS = { - **Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, exp.ColumnDef: preprocess( [ _auto_increment_to_serial, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index a2d392c..11ea778 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -1,4 +1,6 @@ -from sqlglot import exp, transforms +from __future__ import annotations + +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, 
format_time_lambda, @@ -10,10 +12,8 @@ from sqlglot.dialects.dialect import ( struct_extract_sql, ) from sqlglot.dialects.mysql import MySQL -from sqlglot.generator import Generator -from sqlglot.helper import list_get -from sqlglot.parser import Parser -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.helper import seq_get +from sqlglot.tokens import TokenType def _approx_distinct_sql(self, expression): @@ -110,30 +110,29 @@ class Presto(Dialect): index_offset = 1 null_ordering = "nulls_are_last" time_format = "'%Y-%m-%d %H:%i:%S'" - time_mapping = MySQL.time_mapping + time_mapping = MySQL.time_mapping # type: ignore - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): KEYWORDS = { - **Tokenizer.KEYWORDS, - "VARBINARY": TokenType.BINARY, + **tokens.Tokenizer.KEYWORDS, "ROW": TokenType.STRUCT, } - class Parser(Parser): + class Parser(parser.Parser): FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, "CARDINALITY": exp.ArraySize.from_arg_list, "CONTAINS": exp.ArrayContains.from_arg_list, "DATE_ADD": lambda args: exp.DateAdd( - this=list_get(args, 2), - expression=list_get(args, 1), - unit=list_get(args, 0), + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=seq_get(args, 0), ), "DATE_DIFF": lambda args: exp.DateDiff( - this=list_get(args, 2), - expression=list_get(args, 1), - unit=list_get(args, 0), + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=seq_get(args, 0), ), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), @@ -143,7 +142,7 @@ class Presto(Dialect): "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, } - class Generator(Generator): + class Generator(generator.Generator): STRUCT_DELIMITER = ("(", ")") @@ -159,7 +158,7 @@ class Presto(Dialect): } TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.BINARY: "VARBINARY", @@ -169,8 +168,8 @@ class Presto(Dialect): } TRANSFORMS = { - **Generator.TRANSFORMS, - **transforms.UNALIAS_GROUP, + **generator.Generator.TRANSFORMS, + **transforms.UNALIAS_GROUP, # type: ignore exp.ApproxDistinct: _approx_distinct_sql, exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", exp.ArrayConcat: rename_func("CONCAT"), diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index e1f7b78..a9b12fb 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from sqlglot import exp from sqlglot.dialects.postgres import Postgres from sqlglot.tokens import TokenType @@ -6,29 +8,30 @@ from sqlglot.tokens import TokenType class Redshift(Postgres): time_format = "'YYYY-MM-DD HH:MI:SS'" time_mapping = { - **Postgres.time_mapping, + **Postgres.time_mapping, # type: ignore "MON": "%b", "HH": "%H", } class Tokenizer(Postgres.Tokenizer): - ESCAPE = "\\" + ESCAPES = ["\\"] KEYWORDS = { - **Postgres.Tokenizer.KEYWORDS, + **Postgres.Tokenizer.KEYWORDS, # type: ignore "GEOMETRY": TokenType.GEOMETRY, "GEOGRAPHY": TokenType.GEOGRAPHY, "HLLSKETCH": TokenType.HLLSKETCH, "SUPER": TokenType.SUPER, "TIME": TokenType.TIMESTAMP, "TIMETZ": TokenType.TIMESTAMPTZ, - "VARBYTE": TokenType.BINARY, + "VARBYTE": TokenType.VARBINARY, "SIMILAR TO": TokenType.SIMILAR_TO, } class Generator(Postgres.Generator): TYPE_MAPPING = { - **Postgres.Generator.TYPE_MAPPING, + 
**Postgres.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.BINARY: "VARBYTE", + exp.DataType.Type.VARBINARY: "VARBYTE", exp.DataType.Type.INT: "INTEGER", } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 3b97e6d..d1aaded 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -1,4 +1,6 @@ -from sqlglot import exp +from __future__ import annotations + +from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, format_time_lambda, @@ -6,10 +8,8 @@ from sqlglot.dialects.dialect import ( rename_func, ) from sqlglot.expressions import Literal -from sqlglot.generator import Generator -from sqlglot.helper import list_get -from sqlglot.parser import Parser -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.helper import seq_get +from sqlglot.tokens import TokenType def _check_int(s): @@ -28,7 +28,9 @@ def _snowflake_to_timestamp(args): # case: [ , ] if second_arg.name not in ["0", "3", "9"]: - raise ValueError(f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9") + raise ValueError( + f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9" + ) if second_arg.name == "0": timescale = exp.UnixToTime.SECONDS @@ -39,7 +41,7 @@ def _snowflake_to_timestamp(args): return exp.UnixToTime(this=first_arg, scale=timescale) - first_arg = list_get(args, 0) + first_arg = seq_get(args, 0) if not isinstance(first_arg, Literal): # case: return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args) @@ -56,7 +58,7 @@ def _snowflake_to_timestamp(args): return exp.UnixToTime.from_arg_list(args) -def _unix_to_time(self, expression): +def _unix_to_time_sql(self, expression): scale = expression.args.get("scale") timestamp = self.sql(expression, "this") if scale in [None, exp.UnixToTime.SECONDS]: @@ -132,9 +134,9 @@ class Snowflake(Dialect): "ff6": "%f", } - class Parser(Parser): + class Parser(parser.Parser): FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, "IFF": exp.If.from_arg_list, "TO_TIMESTAMP": _snowflake_to_timestamp, @@ -143,18 +145,18 @@ class Snowflake(Dialect): } FUNCTION_PARSERS = { - **Parser.FUNCTION_PARSERS, + **parser.Parser.FUNCTION_PARSERS, "DATE_PART": _parse_date_part, } FUNC_TOKENS = { - *Parser.FUNC_TOKENS, + *parser.Parser.FUNC_TOKENS, TokenType.RLIKE, TokenType.TABLE, } COLUMN_OPERATORS = { - **Parser.COLUMN_OPERATORS, + **parser.Parser.COLUMN_OPERATORS, # type: ignore TokenType.COLON: lambda self, this, path: self.expression( exp.Bracket, this=this, @@ -163,21 +165,21 @@ class Snowflake(Dialect): } PROPERTY_PARSERS = { - **Parser.PROPERTY_PARSERS, + **parser.Parser.PROPERTY_PARSERS, TokenType.PARTITION_BY: lambda self: self._parse_partitioned_by(), } - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): QUOTES = ["'", "$$"] - ESCAPE = "\\" + ESCAPES = ["\\"] SINGLE_TOKENS = { - **Tokenizer.SINGLE_TOKENS, + **tokens.Tokenizer.SINGLE_TOKENS, "$": TokenType.PARAMETER, } KEYWORDS = { - **Tokenizer.KEYWORDS, + **tokens.Tokenizer.KEYWORDS, "QUALIFY": TokenType.QUALIFY, "DOUBLE PRECISION": TokenType.DOUBLE, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, @@ -187,15 +189,15 @@ class Snowflake(Dialect): "SAMPLE": TokenType.TABLE_SAMPLE, } - class Generator(Generator): + class Generator(generator.Generator): CREATE_TRANSIENT = True TRANSFORMS = { - **Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.If: 
rename_func("IFF"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", - exp.UnixToTime: _unix_to_time, + exp.UnixToTime: _unix_to_time_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.Array: inline_array_sql, exp.StrPosition: rename_func("POSITION"), @@ -204,7 +206,7 @@ class Snowflake(Dialect): } TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", } diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 572f411..4e404b8 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -1,8 +1,9 @@ -from sqlglot import exp +from __future__ import annotations + +from sqlglot import exp, parser from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func from sqlglot.dialects.hive import Hive -from sqlglot.helper import list_get -from sqlglot.parser import Parser +from sqlglot.helper import seq_get def _create_sql(self, e): @@ -46,36 +47,36 @@ def _unix_to_time(self, expression): class Spark(Hive): class Parser(Hive.Parser): FUNCTIONS = { - **Hive.Parser.FUNCTIONS, + **Hive.Parser.FUNCTIONS, # type: ignore "MAP_FROM_ARRAYS": exp.Map.from_arg_list, "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, "LEFT": lambda args: exp.Substring( - this=list_get(args, 0), + this=seq_get(args, 0), start=exp.Literal.number(1), - length=list_get(args, 1), + length=seq_get(args, 1), ), "SHIFTLEFT": lambda args: exp.BitwiseLeftShift( - this=list_get(args, 0), - expression=list_get(args, 1), + this=seq_get(args, 0), + expression=seq_get(args, 1), ), "SHIFTRIGHT": lambda args: exp.BitwiseRightShift( - this=list_get(args, 0), - expression=list_get(args, 1), + this=seq_get(args, 0), + expression=seq_get(args, 1), ), "RIGHT": lambda args: exp.Substring( - this=list_get(args, 0), + this=seq_get(args, 0), start=exp.Sub( - this=exp.Length(this=list_get(args, 0)), - expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)), + this=exp.Length(this=seq_get(args, 0)), + expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)), ), - length=list_get(args, 1), + length=seq_get(args, 1), ), "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, "IIF": exp.If.from_arg_list, } FUNCTION_PARSERS = { - **Parser.FUNCTION_PARSERS, + **parser.Parser.FUNCTION_PARSERS, "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"), "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"), "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"), @@ -88,14 +89,14 @@ class Spark(Hive): class Generator(Hive.Generator): TYPE_MAPPING = { - **Hive.Generator.TYPE_MAPPING, + **Hive.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TINYINT: "BYTE", exp.DataType.Type.SMALLINT: "SHORT", exp.DataType.Type.BIGINT: "LONG", } TRANSFORMS = { - **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort, exp.ILike}}, + **Hive.Generator.TRANSFORMS, # type: ignore exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}", exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", @@ -114,6 +115,8 @@ class Spark(Hive): exp.VariancePop: rename_func("VAR_POP"), exp.DateFromParts: rename_func("MAKE_DATE"), } + TRANSFORMS.pop(exp.ArraySort) + TRANSFORMS.pop(exp.ILike) WRAP_DERIVED_VALUES = False diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py 
index 62b7617..8c9fb76 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -1,4 +1,6 @@ -from sqlglot import exp +from __future__ import annotations + +from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, arrow_json_extract_scalar_sql, @@ -8,31 +10,28 @@ from sqlglot.dialects.dialect import ( no_trycast_sql, rename_func, ) -from sqlglot.generator import Generator -from sqlglot.parser import Parser -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.tokens import TokenType class SQLite(Dialect): - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]"), "`"] HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")] KEYWORDS = { - **Tokenizer.KEYWORDS, - "VARBINARY": TokenType.BINARY, + **tokens.Tokenizer.KEYWORDS, "AUTOINCREMENT": TokenType.AUTO_INCREMENT, } - class Parser(Parser): + class Parser(parser.Parser): FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "EDITDIST3": exp.Levenshtein.from_arg_list, } - class Generator(Generator): + class Generator(generator.Generator): TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.BOOLEAN: "INTEGER", exp.DataType.Type.TINYINT: "INTEGER", exp.DataType.Type.SMALLINT: "INTEGER", @@ -46,6 +45,7 @@ class SQLite(Dialect): exp.DataType.Type.VARCHAR: "TEXT", exp.DataType.Type.NVARCHAR: "TEXT", exp.DataType.Type.BINARY: "BLOB", + exp.DataType.Type.VARBINARY: "BLOB", } TOKEN_MAPPING = { @@ -53,7 +53,7 @@ class SQLite(Dialect): } TRANSFORMS = { - **Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, exp.ILike: no_ilike_sql, exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index 0cba6fe..3519c09 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -1,10 +1,12 @@ +from __future__ import annotations + from sqlglot import exp from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func from sqlglot.dialects.mysql import MySQL class StarRocks(MySQL): - class Generator(MySQL.Generator): + class Generator(MySQL.Generator): # type: ignore TYPE_MAPPING = { **MySQL.Generator.TYPE_MAPPING, exp.DataType.Type.TEXT: "STRING", @@ -13,7 +15,7 @@ class StarRocks(MySQL): } TRANSFORMS = { - **MySQL.Generator.TRANSFORMS, + **MySQL.Generator.TRANSFORMS, # type: ignore exp.JSONExtractScalar: arrow_json_extract_sql, exp.JSONExtract: arrow_json_extract_sql, exp.DateDiff: rename_func("DATEDIFF"), @@ -22,3 +24,4 @@ class StarRocks(MySQL): exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: rename_func("FROM_UNIXTIME"), } + TRANSFORMS.pop(exp.DateTrunc) diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index 45aa041..63e7275 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -1,7 +1,7 @@ -from sqlglot import exp +from __future__ import annotations + +from sqlglot import exp, generator, parser from sqlglot.dialects.dialect import Dialect -from sqlglot.generator import Generator -from sqlglot.parser import Parser def _if_sql(self, expression): @@ -20,17 +20,17 @@ def _count_sql(self, expression): class Tableau(Dialect): - class Generator(Generator): + class Generator(generator.Generator): TRANSFORMS = { - **Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.If: _if_sql, exp.Coalesce: _coalesce_sql, 
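# Annotation (illustrative, not part of the upstream patch): together with
# the Tableau Parser mappings below (IFNULL -> exp.Coalesce, COUNTD ->
# COUNT(DISTINCT ...)), these transforms make Tableau calculations
# transpilable, so something like
#
#     sqlglot.transpile("SELECT COUNTD(user_id) FROM t", read="tableau", write="presto")
#
# would be expected to produce SELECT COUNT(DISTINCT user_id) FROM t.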
exp.Count: _count_sql, } - class Parser(Parser): + class Parser(parser.Parser): FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "IFNULL": exp.Coalesce.from_arg_list, "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)), } diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py index 9a6f7fe..c7b34fe 100644 --- a/sqlglot/dialects/trino.py +++ b/sqlglot/dialects/trino.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from sqlglot import exp from sqlglot.dialects.presto import Presto @@ -5,7 +7,7 @@ from sqlglot.dialects.presto import Presto class Trino(Presto): class Generator(Presto.Generator): TRANSFORMS = { - **Presto.Generator.TRANSFORMS, + **Presto.Generator.TRANSFORMS, # type: ignore exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", } diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 0f93c75..a233d4b 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -1,15 +1,22 @@ +from __future__ import annotations + import re -from sqlglot import exp +from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func from sqlglot.expressions import DataType -from sqlglot.generator import Generator -from sqlglot.helper import list_get -from sqlglot.parser import Parser +from sqlglot.helper import seq_get from sqlglot.time import format_time -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.tokens import TokenType -FULL_FORMAT_TIME_MAPPING = {"weekday": "%A", "dw": "%A", "w": "%A", "month": "%B", "mm": "%B", "m": "%B"} +FULL_FORMAT_TIME_MAPPING = { + "weekday": "%A", + "dw": "%A", + "w": "%A", + "month": "%B", + "mm": "%B", + "m": "%B", +} DATE_DELTA_INTERVAL = { "year": "year", "yyyy": "year", @@ -37,11 +44,13 @@ TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"} def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None): def _format_time(args): return exp_class( - this=list_get(args, 1), + this=seq_get(args, 1), format=exp.Literal.string( format_time( - list_get(args, 0).name or (TSQL.time_format if default is True else default), - {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping else TSQL.time_mapping, + seq_get(args, 0).name or (TSQL.time_format if default is True else default), + {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} + if full_format_mapping + else TSQL.time_mapping, ) ), ) @@ -50,12 +59,12 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None): def parse_format(args): - fmt = list_get(args, 1) + fmt = seq_get(args, 1) number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this) if number_fmt: - return exp.NumberToStr(this=list_get(args, 0), format=fmt) + return exp.NumberToStr(this=seq_get(args, 0), format=fmt) return exp.TimeToStr( - this=list_get(args, 0), + this=seq_get(args, 0), format=exp.Literal.string( format_time(fmt.name, TSQL.format_time_mapping) if len(fmt.name) == 1 @@ -188,11 +197,11 @@ class TSQL(Dialect): "Y": "%a %Y", } - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]")] KEYWORDS = { - **Tokenizer.KEYWORDS, + **tokens.Tokenizer.KEYWORDS, "BIT": TokenType.BOOLEAN, "REAL": TokenType.FLOAT, "NTEXT": TokenType.TEXT, @@ -200,7 +209,6 @@ class TSQL(Dialect): "DATETIME2": TokenType.DATETIME, "DATETIMEOFFSET": TokenType.TIMESTAMPTZ, "TIME": TokenType.TIMESTAMP, - "VARBINARY": TokenType.BINARY, "IMAGE": TokenType.IMAGE, "MONEY": 
TokenType.MONEY, "SMALLMONEY": TokenType.SMALLMONEY, @@ -213,9 +221,9 @@ class TSQL(Dialect): "TOP": TokenType.TOP, } - class Parser(Parser): + class Parser(parser.Parser): FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "CHARINDEX": exp.StrPosition.from_arg_list, "ISNULL": exp.Coalesce.from_arg_list, "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), @@ -243,14 +251,16 @@ class TSQL(Dialect): this = self._parse_column() # Retrieve length of datatype and override to default if not specified - if list_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES: + if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES: to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False) # Check whether a conversion with format is applicable if self._match(TokenType.COMMA): format_val = self._parse_number().name if format_val not in TSQL.convert_format_mapping: - raise ValueError(f"CONVERT function at T-SQL does not support format style {format_val}") + raise ValueError( + f"CONVERT function at T-SQL does not support format style {format_val}" + ) format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val]) # Check whether the convert entails a string to date format @@ -272,9 +282,9 @@ class TSQL(Dialect): # Entails a simple cast without any format requirement return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) - class Generator(Generator): + class Generator(generator.Generator): TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.BOOLEAN: "BIT", exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.DECIMAL: "NUMERIC", @@ -283,7 +293,7 @@ class TSQL(Dialect): } TRANSFORMS = { - **Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, exp.CurrentDate: rename_func("GETDATE"), diff --git a/sqlglot/diff.py b/sqlglot/diff.py index 0567c12..2d959ab 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -4,7 +4,7 @@ from heapq import heappop, heappush from sqlglot import Dialect from sqlglot import expressions as exp -from sqlglot.helper import ensure_list +from sqlglot.helper import ensure_collection @dataclass(frozen=True) @@ -116,7 +116,9 @@ class ChangeDistiller: source_node = self._source_index[kept_source_node_id] target_node = self._target_index[kept_target_node_id] if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node: - edit_script.extend(self._generate_move_edits(source_node, target_node, matching_set)) + edit_script.extend( + self._generate_move_edits(source_node, target_node, matching_set) + ) edit_script.append(Keep(source_node, target_node)) else: edit_script.append(Update(source_node, target_node)) @@ -158,13 +160,16 @@ class ChangeDistiller: max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids)) if max_leaves_num: common_leaves_num = sum( - 1 if s in source_leaf_ids and t in target_leaf_ids else 0 for s, t in leaves_matching_set + 1 if s in source_leaf_ids and t in target_leaf_ids else 0 + for s, t in leaves_matching_set ) leaf_similarity_score = common_leaves_num / max_leaves_num else: leaf_similarity_score = 0.0 - adjusted_t = self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4 + adjusted_t = ( + self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4 + ) if leaf_similarity_score >= 0.8 or ( leaf_similarity_score >= 
adjusted_t @@ -201,7 +206,10 @@ class ChangeDistiller: matching_set = set() while candidate_matchings: _, _, source_leaf, target_leaf = heappop(candidate_matchings) - if id(source_leaf) in self._unmatched_source_nodes and id(target_leaf) in self._unmatched_target_nodes: + if ( + id(source_leaf) in self._unmatched_source_nodes + and id(target_leaf) in self._unmatched_target_nodes + ): matching_set.add((id(source_leaf), id(target_leaf))) self._unmatched_source_nodes.remove(id(source_leaf)) self._unmatched_target_nodes.remove(id(target_leaf)) @@ -241,8 +249,7 @@ def _get_leaves(expression): has_child_exprs = False for a in expression.args.values(): - nodes = ensure_list(a) - for node in nodes: + for node in ensure_collection(a): if isinstance(node, exp.Expression): has_child_exprs = True yield from _get_leaves(node) @@ -268,7 +275,7 @@ def _expression_only_args(expression): args = [] if expression: for a in expression.args.values(): - args.extend(ensure_list(a)) + args.extend(ensure_collection(a)) return [a for a in args if isinstance(a, exp.Expression)] diff --git a/sqlglot/errors.py b/sqlglot/errors.py index 89aa935..2ef908f 100644 --- a/sqlglot/errors.py +++ b/sqlglot/errors.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import typing as t from enum import auto from sqlglot.helper import AutoName @@ -30,7 +33,11 @@ class OptimizeError(SqlglotError): pass -def concat_errors(errors, maximum): +class SchemaError(SqlglotError): + pass + + +def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str: msg = [str(e) for e in errors[:maximum]] remaining = len(errors) - maximum if remaining > 0: diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py index d265a2c..393347b 100644 --- a/sqlglot/executor/context.py +++ b/sqlglot/executor/context.py @@ -19,6 +19,7 @@ class Context: env (Optional[dict]): dictionary of functions within the execution context """ self.tables = tables + self._table = None self.range_readers = {name: table.range_reader for name, table in self.tables.items()} self.row_readers = {name: table.reader for name, table in tables.items()} self.env = {**(env or {}), "scope": self.row_readers} @@ -29,8 +30,27 @@ class Context: def eval_tuple(self, codes): return tuple(self.eval(code) for code in codes) + @property + def table(self): + if self._table is None: + self._table = list(self.tables.values())[0] + for other in self.tables.values(): + if self._table.columns != other.columns: + raise Exception(f"Columns are different.") + if len(self._table.rows) != len(other.rows): + raise Exception(f"Rows are different.") + return self._table + + @property + def columns(self): + return self.table.columns + def __iter__(self): - return self.table_iter(list(self.tables)[0]) + self.env["scope"] = self.row_readers + for i in range(len(self.table.rows)): + for table in self.tables.values(): + reader = table[i] + yield reader, self def table_iter(self, table): self.env["scope"] = self.row_readers @@ -38,8 +58,8 @@ class Context: for reader in self.tables[table]: yield reader, self - def sort(self, table, key): - table = self.tables[table] + def sort(self, key): + table = self.table def sort_key(row): table.reader.row = row @@ -47,20 +67,20 @@ class Context: table.rows.sort(key=sort_key) - def set_row(self, table, row): - self.row_readers[table].row = row + def set_row(self, row): + for table in self.tables.values(): + table.reader.row = row self.env["scope"] = self.row_readers - def set_index(self, table, index): - self.row_readers[table].row = 
self.tables[table].rows[index] + def set_index(self, index): + for table in self.tables.values(): + table[index] self.env["scope"] = self.row_readers - def set_range(self, table, start, end): - self.range_readers[table].range = range(start, end) + def set_range(self, start, end): + for name in self.tables: + self.range_readers[name].range = range(start, end) self.env["scope"] = self.range_readers - def __getitem__(self, table): - return self.env["scope"][table] - def __contains__(self, table): return table in self.tables diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index 9c49dd1..bbe6c81 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -2,6 +2,8 @@ import datetime import re import statistics +from sqlglot.helper import PYTHON_VERSION + class reverse_key: def __init__(self, obj): @@ -25,7 +27,7 @@ ENV = { "str": str, "desc": reverse_key, "SUM": sum, - "AVG": statistics.fmean if hasattr(statistics, "fmean") else statistics.mean, + "AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean, # type: ignore "COUNT": lambda acc: sum(1 for e in acc if e is not None), "MAX": max, "MIN": min, diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index fcb016b..7d1db32 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -1,15 +1,14 @@ import ast import collections import itertools +import math -from sqlglot import exp, planner +from sqlglot import exp, generator, planner, tokens from sqlglot.dialects.dialect import Dialect, inline_array_sql from sqlglot.executor.context import Context from sqlglot.executor.env import ENV from sqlglot.executor.table import Table -from sqlglot.generator import Generator from sqlglot.helper import csv_reader -from sqlglot.tokens import Tokenizer class PythonExecutor: @@ -26,7 +25,11 @@ class PythonExecutor: while queue: node = queue.pop() context = self.context( - {name: table for dep in node.dependencies for name, table in contexts[dep].tables.items()} + { + name: table + for dep in node.dependencies + for name, table in contexts[dep].tables.items() + } ) running.add(node) @@ -76,13 +79,10 @@ class PythonExecutor: return Table(expression.alias_or_name for expression in expressions) def scan(self, step, context): - if hasattr(step, "source"): - source = step.source + source = step.source - if isinstance(source, exp.Expression): - source = source.name or source.alias - else: - source = step.name + if isinstance(source, exp.Expression): + source = source.name or source.alias condition = self.generate(step.condition) projections = self.generate_tuple(step.projections) @@ -96,14 +96,12 @@ class PythonExecutor: if projections: sink = self.table(step.projections) - elif source in context: - sink = Table(context[source].columns) else: sink = None for reader, ctx in table_iter: if sink is None: - sink = Table(ctx[source].columns) + sink = Table(reader.columns) if condition and not ctx.eval(condition): continue @@ -135,98 +133,79 @@ class PythonExecutor: types.append(type(ast.literal_eval(v))) except (ValueError, SyntaxError): types.append(str) - context.set_row(alias, tuple(t(v) for t, v in zip(types, row))) - yield context[alias], context + context.set_row(tuple(t(v) for t, v in zip(types, row))) + yield context.table.reader, context def join(self, step, context): source = step.name - join_context = self.context({source: context.tables[source]}) - - def merge_context(ctx, table): - # create a new context where all existing tables are mapped to a new one - return self.context({name: 
table for name in ctx.tables}) + source_table = context.tables[source] + source_context = self.context({source: source_table}) + column_ranges = {source: range(0, len(source_table.columns))} for name, join in step.joins.items(): - join_context = self.context({**join_context.tables, name: context.tables[name]}) + table = context.tables[name] + start = max(r.stop for r in column_ranges.values()) + column_ranges[name] = range(start, len(table.columns) + start) + join_context = self.context({name: table}) if join.get("source_key"): - table = self.hash_join(join, source, name, join_context) + table = self.hash_join(join, source_context, join_context) else: - table = self.nested_loop_join(join, source, name, join_context) + table = self.nested_loop_join(join, source_context, join_context) - join_context = merge_context(join_context, table) - - # apply projections or conditions - context = self.scan(step, join_context) + source_context = self.context( + { + name: Table(table.columns, table.rows, column_range) + for name, column_range in column_ranges.items() + } + ) - # use the scan context since it returns a single table - # otherwise there are no projections so all other tables are still in scope - if step.projections: - return context + condition = self.generate(step.condition) + projections = self.generate_tuple(step.projections) - return merge_context(join_context, context.tables[source]) + if not condition or not projections: + return source_context - def nested_loop_join(self, _join, a, b, context): - table = Table(context.tables[a].columns + context.tables[b].columns) + sink = self.table(step.projections if projections else source_context.columns) - for reader_a, _ in context.table_iter(a): - for reader_b, _ in context.table_iter(b): - table.append(reader_a.row + reader_b.row) + for reader, ctx in join_context: + if condition and not ctx.eval(condition): + continue - return table + if projections: + sink.append(ctx.eval_tuple(projections)) + else: + sink.append(reader.row) - def hash_join(self, join, a, b, context): - a_key = self.generate_tuple(join["source_key"]) - b_key = self.generate_tuple(join["join_key"]) + if len(sink) >= step.limit: + break - results = collections.defaultdict(lambda: ([], [])) + return self.context({step.name: sink}) - for reader, ctx in context.table_iter(a): - results[ctx.eval_tuple(a_key)][0].append(reader.row) - for reader, ctx in context.table_iter(b): - results[ctx.eval_tuple(b_key)][1].append(reader.row) + def nested_loop_join(self, _join, source_context, join_context): + table = Table(source_context.columns + join_context.columns) - table = Table(context.tables[a].columns + context.tables[b].columns) - for a_group, b_group in results.values(): - for a_row, b_row in itertools.product(a_group, b_group): - table.append(a_row + b_row) + for reader_a, _ in source_context: + for reader_b, _ in join_context: + table.append(reader_a.row + reader_b.row) return table - def sort_merge_join(self, join, a, b, context): - a_key = self.generate_tuple(join["source_key"]) - b_key = self.generate_tuple(join["join_key"]) - - context.sort(a, a_key) - context.sort(b, b_key) - - a_i = 0 - b_i = 0 - a_n = len(context.tables[a]) - b_n = len(context.tables[b]) - - table = Table(context.tables[a].columns + context.tables[b].columns) - - def get_key(source, key, i): - context.set_index(source, i) - return context.eval_tuple(key) + def hash_join(self, join, source_context, join_context): + source_key = self.generate_tuple(join["source_key"]) + join_key = 
self.generate_tuple(join["join_key"]) - while a_i < a_n and b_i < b_n: - key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i)) - - a_group = [] - - while a_i < a_n and key == get_key(a, a_key, a_i): - a_group.append(context[a].row) - a_i += 1 + results = collections.defaultdict(lambda: ([], [])) - b_group = [] + for reader, ctx in source_context: + results[ctx.eval_tuple(source_key)][0].append(reader.row) + for reader, ctx in join_context: + results[ctx.eval_tuple(join_key)][1].append(reader.row) - while b_i < b_n and key == get_key(b, b_key, b_i): - b_group.append(context[b].row) - b_i += 1 + table = Table(source_context.columns + join_context.columns) + for a_group, b_group in results.values(): for a_row, b_row in itertools.product(a_group, b_group): table.append(a_row + b_row) @@ -238,16 +217,18 @@ class PythonExecutor: aggregations = self.generate_tuple(step.aggregations) operands = self.generate_tuple(step.operands) - context.sort(source, group_by) - - if step.operands: + if operands: source_table = context.tables[source] operand_table = Table(source_table.columns + self.table(step.operands).columns) for reader, ctx in context: operand_table.append(reader.row + ctx.eval_tuple(operands)) - context = self.context({source: operand_table}) + context = self.context( + {None: operand_table, **{table: operand_table for table in context.tables}} + ) + + context.sort(group_by) group = None start = 0 @@ -256,15 +237,15 @@ class PythonExecutor: table = self.table(step.group + step.aggregations) for i in range(length): - context.set_index(source, i) + context.set_index(i) key = context.eval_tuple(group_by) group = key if group is None else group end += 1 if i == length - 1: - context.set_range(source, start, end - 1) + context.set_range(start, end - 1) elif key != group: - context.set_range(source, start, end - 2) + context.set_range(start, end - 2) else: continue @@ -272,13 +253,32 @@ class PythonExecutor: group = key start = end - 2 - return self.scan(step, self.context({source: table})) + context = self.context({step.name: table, **{name: table for name in context.tables}}) + + if step.projections: + return self.scan(step, context) + return context def sort(self, step, context): - table = list(context.tables)[0] - key = self.generate_tuple(step.key) - context.sort(table, key) - return self.scan(step, context) + projections = self.generate_tuple(step.projections) + + sink = self.table(step.projections) + + for reader, ctx in context: + sink.append(ctx.eval_tuple(projections)) + + context = self.context( + { + None: sink, + **{table: sink for table in context.tables}, + } + ) + context.sort(self.generate_tuple(step.key)) + + if not math.isinf(step.limit): + context.table.rows = context.table.rows[0 : step.limit] + + return self.context({step.name: context.table}) def _cast_py(self, expression): @@ -293,7 +293,7 @@ def _cast_py(self, expression): def _column_py(self, expression): - table = self.sql(expression, "table") + table = self.sql(expression, "table") or None this = self.sql(expression, "this") return f"scope[{table}][{this}]" @@ -319,10 +319,10 @@ def _ordered_py(self, expression): class Python(Dialect): - class Tokenizer(Tokenizer): - ESCAPE = "\\" + class Tokenizer(tokens.Tokenizer): + ESCAPES = ["\\"] - class Generator(Generator): + class Generator(generator.Generator): TRANSFORMS = { exp.Alias: lambda self, e: self.sql(e.this), exp.Array: inline_array_sql, diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py index 80674cb..6796740 100644 --- 
a/sqlglot/executor/table.py +++ b/sqlglot/executor/table.py @@ -1,10 +1,12 @@ class Table: - def __init__(self, *columns, rows=None): - self.columns = tuple(columns if isinstance(columns[0], str) else columns[0]) + def __init__(self, columns, rows=None, column_range=None): + self.columns = tuple(columns) + self.column_range = column_range + self.reader = RowReader(self.columns, self.column_range) + self.rows = rows or [] if rows: assert len(rows[0]) == len(self.columns) - self.reader = RowReader(self.columns) self.range_reader = RangeReader(self) def append(self, row): @@ -29,15 +31,22 @@ class Table: return self.reader def __repr__(self): - widths = {column: len(column) for column in self.columns} - lines = [" ".join(column for column in self.columns)] + columns = tuple( + column + for i, column in enumerate(self.columns) + if not self.column_range or i in self.column_range + ) + widths = {column: len(column) for column in columns} + lines = [" ".join(column for column in columns)] for i, row in enumerate(self): if i > 10: break lines.append( - " ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in self.columns) + " ".join( + str(row[column]).rjust(widths[column])[0 : widths[column]] for column in columns + ) ) return "\n".join(lines) @@ -70,8 +79,10 @@ class RangeReader: class RowReader: - def __init__(self, columns): - self.columns = {column: i for i, column in enumerate(columns)} + def __init__(self, columns, column_range=None): + self.columns = { + column: i for i, column in enumerate(columns) if not column_range or i in column_range + } self.row = None def __getitem__(self, column): diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 1691d85..57a2c88 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import datetime import numbers import re +import typing as t from collections import deque from copy import deepcopy from enum import auto @@ -9,12 +12,15 @@ from sqlglot.errors import ParseError from sqlglot.helper import ( AutoName, camel_to_snake_case, - ensure_list, - list_get, + ensure_collection, + seq_get, split_num_words, subclasses, ) +if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import Dialect + class _Expression(type): def __new__(cls, clsname, bases, attrs): @@ -35,27 +41,30 @@ class Expression(metaclass=_Expression): or optional (False). """ - key = None + key = "Expression" arg_types = {"this": True} - __slots__ = ("args", "parent", "arg_key", "type") + __slots__ = ("args", "parent", "arg_key", "type", "comment") def __init__(self, **args): self.args = args self.parent = None self.arg_key = None self.type = None + self.comment = None for arg_key, value in self.args.items(): self._set_parent(arg_key, value) - def __eq__(self, other): + def __eq__(self, other) -> bool: return type(self) is type(other) and _norm_args(self) == _norm_args(other) - def __hash__(self): + def __hash__(self) -> int: return hash( ( self.key, - tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()), + tuple( + (k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items() + ), ) ) @@ -79,6 +88,19 @@ class Expression(metaclass=_Expression): return field.this return "" + def find_comment(self, key: str) -> str: + """ + Finds the comment that is attached to a specified child node. + + Args: + key: the key of the target child node (e.g. "this", "expression", etc). 
+ + Returns: + The comment attached to the child node, or the empty string, if it doesn't exist. + """ + field = self.args.get(key) + return field.comment if isinstance(field, Expression) else "" + @property def is_string(self): return isinstance(self, Literal) and self.args["is_string"] @@ -114,7 +136,10 @@ class Expression(metaclass=_Expression): return self.alias or self.name def __deepcopy__(self, memo): - return self.__class__(**deepcopy(self.args)) + copy = self.__class__(**deepcopy(self.args)) + copy.comment = self.comment + copy.type = self.type + return copy def copy(self): new = deepcopy(self) @@ -249,9 +274,7 @@ class Expression(metaclass=_Expression): return for k, v in self.args.items(): - nodes = ensure_list(v) - - for node in nodes: + for node in ensure_collection(v): if isinstance(node, Expression): yield from node.dfs(self, k, prune) @@ -274,9 +297,7 @@ class Expression(metaclass=_Expression): if isinstance(item, Expression): for k, v in item.args.items(): - nodes = ensure_list(v) - - for node in nodes: + for node in ensure_collection(v): if isinstance(node, Expression): queue.append((node, item, k)) @@ -319,7 +340,7 @@ class Expression(metaclass=_Expression): def __repr__(self): return self.to_s() - def sql(self, dialect=None, **opts): + def sql(self, dialect: Dialect | str | None = None, **opts) -> str: """ Returns SQL string representation of this tree. @@ -335,7 +356,7 @@ class Expression(metaclass=_Expression): return Dialect.get_or_raise(dialect)().generate(self, **opts) - def to_s(self, hide_missing=True, level=0): + def to_s(self, hide_missing: bool = True, level: int = 0) -> str: indent = "" if not level else "\n" indent += "".join([" "] * level) left = f"({self.key.upper()} " @@ -343,11 +364,13 @@ class Expression(metaclass=_Expression): args = { k: ", ".join( v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v) - for v in ensure_list(vs) + for v in ensure_collection(vs) if v is not None ) for k, vs in self.args.items() } + args["comment"] = self.comment + args["type"] = self.type args = {k: v for k, v in args.items() if v or not hide_missing} right = ", ".join(f"{k}: {v}" for k, v in args.items()) @@ -578,17 +601,6 @@ class UDTF(DerivedTable, Unionable): pass -class Annotation(Expression): - arg_types = { - "this": True, - "expression": True, - } - - @property - def alias(self): - return self.expression.alias_or_name - - class Cache(Expression): arg_types = { "with": False, @@ -623,6 +635,38 @@ class Describe(Expression): pass +class Set(Expression): + arg_types = {"expressions": True} + + +class SetItem(Expression): + arg_types = { + "this": True, + "kind": False, + "collate": False, # MySQL SET NAMES statement + } + + +class Show(Expression): + arg_types = { + "this": True, + "target": False, + "offset": False, + "limit": False, + "like": False, + "where": False, + "db": False, + "full": False, + "mutex": False, + "query": False, + "channel": False, + "global": False, + "log": False, + "position": False, + "types": False, + } + + class UserDefinedFunction(Expression): arg_types = {"this": True, "expressions": False} @@ -864,18 +908,20 @@ class Literal(Condition): def __eq__(self, other): return ( - isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"] + isinstance(other, Literal) + and self.this == other.this + and self.args["is_string"] == other.args["is_string"] ) def __hash__(self): return hash((self.key, self.this, self.args["is_string"])) @classmethod - def number(cls, 
number): + def number(cls, number) -> Literal: return cls(this=str(number), is_string=False) @classmethod - def string(cls, string): + def string(cls, string) -> Literal: return cls(this=str(string), is_string=True) @@ -1087,7 +1133,7 @@ class Properties(Expression): } @classmethod - def from_dict(cls, properties_dict): + def from_dict(cls, properties_dict) -> Properties: expressions = [] for key, value in properties_dict.items(): property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty) @@ -1323,7 +1369,7 @@ class Select(Subqueryable): **QUERY_MODIFIERS, } - def from_(self, *expressions, append=True, dialect=None, copy=True, **opts): + def from_(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Set the FROM expression. @@ -1356,7 +1402,7 @@ class Select(Subqueryable): **opts, ) - def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Set the GROUP BY expression. @@ -1392,7 +1438,7 @@ class Select(Subqueryable): **opts, ) - def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Set the ORDER BY expression. @@ -1425,7 +1471,7 @@ class Select(Subqueryable): **opts, ) - def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Set the SORT BY expression. @@ -1458,7 +1504,7 @@ class Select(Subqueryable): **opts, ) - def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Set the CLUSTER BY expression. @@ -1491,7 +1537,7 @@ class Select(Subqueryable): **opts, ) - def limit(self, expression, dialect=None, copy=True, **opts): + def limit(self, expression, dialect=None, copy=True, **opts) -> Select: """ Set the LIMIT expression. @@ -1522,7 +1568,7 @@ class Select(Subqueryable): **opts, ) - def offset(self, expression, dialect=None, copy=True, **opts): + def offset(self, expression, dialect=None, copy=True, **opts) -> Select: """ Set the OFFSET expression. @@ -1553,7 +1599,7 @@ class Select(Subqueryable): **opts, ) - def select(self, *expressions, append=True, dialect=None, copy=True, **opts): + def select(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Append to or set the SELECT expressions. @@ -1583,7 +1629,7 @@ class Select(Subqueryable): **opts, ) - def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts): + def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Append to or set the LATERAL expressions. @@ -1626,7 +1672,7 @@ class Select(Subqueryable): dialect=None, copy=True, **opts, - ): + ) -> Select: """ Append to or set the JOIN expressions. 
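The typed builder methods annotated above compose fluently; a minimal sketch of chaining them (table and column names are invented for illustration):

    from sqlglot import select

    # Each builder copies by default (copy=True), so every call returns a new Select.
    query = (
        select("id", "name")
        .from_("users")
        .where("age > 18")
        .order_by("name")
        .limit(10)
    )
    print(query.sql())  # roughly: SELECT id, name FROM users WHERE age > 18 ORDER BY name LIMIT 10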
@@ -1672,7 +1718,7 @@ class Select(Subqueryable): join.this.replace(join.this.subquery()) if join_type: - natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) + natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore if natural: join.set("natural", True) if side: @@ -1681,12 +1727,12 @@ class Select(Subqueryable): join.set("kind", kind.text) if on: - on = and_(*ensure_list(on), dialect=dialect, **opts) + on = and_(*ensure_collection(on), dialect=dialect, **opts) join.set("on", on) if using: join = _apply_list_builder( - *ensure_list(using), + *ensure_collection(using), instance=join, arg="using", append=append, @@ -1705,7 +1751,7 @@ class Select(Subqueryable): **opts, ) - def where(self, *expressions, append=True, dialect=None, copy=True, **opts): + def where(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Append to or set the WHERE expressions. @@ -1737,7 +1783,7 @@ class Select(Subqueryable): **opts, ) - def having(self, *expressions, append=True, dialect=None, copy=True, **opts): + def having(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Append to or set the HAVING expressions. @@ -1769,7 +1815,7 @@ class Select(Subqueryable): **opts, ) - def distinct(self, distinct=True, copy=True): + def distinct(self, distinct=True, copy=True) -> Select: """ Set the OFFSET expression. @@ -1788,7 +1834,7 @@ class Select(Subqueryable): instance.set("distinct", Distinct() if distinct else None) return instance - def ctas(self, table, properties=None, dialect=None, copy=True, **opts): + def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create: """ Convert this expression to a CREATE TABLE AS statement. @@ -1826,11 +1872,11 @@ class Select(Subqueryable): ) @property - def named_selects(self): + def named_selects(self) -> t.List[str]: return [e.alias_or_name for e in self.expressions if e.alias_or_name] @property - def selects(self): + def selects(self) -> t.List[Expression]: return self.expressions @@ -1910,12 +1956,16 @@ class Parameter(Expression): pass +class SessionParameter(Expression): + arg_types = {"this": True, "kind": False} + + class Placeholder(Expression): arg_types = {"this": False} class Null(Condition): - arg_types = {} + arg_types: t.Dict[str, t.Any] = {} class Boolean(Condition): @@ -1936,6 +1986,7 @@ class DataType(Expression): NVARCHAR = auto() TEXT = auto() BINARY = auto() + VARBINARY = auto() INT = auto() TINYINT = auto() SMALLINT = auto() @@ -1975,7 +2026,7 @@ class DataType(Expression): UNKNOWN = auto() # Sentinel value, useful for type annotation @classmethod - def build(cls, dtype, **kwargs): + def build(cls, dtype, **kwargs) -> DataType: return DataType( this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()], **kwargs, @@ -2077,6 +2128,18 @@ class EQ(Binary, Predicate): pass +class NullSafeEQ(Binary, Predicate): + pass + + +class NullSafeNEQ(Binary, Predicate): + pass + + +class Distance(Binary): + pass + + class Escape(Binary): pass @@ -2101,15 +2164,11 @@ class Is(Binary, Predicate): pass -class Like(Binary, Predicate): - pass - - -class SimilarTo(Binary, Predicate): - pass +class Kwarg(Binary): + """Kwarg in special functions like func(kwarg => y).""" -class Distance(Binary): +class Like(Binary, Predicate): pass @@ -2133,6 +2192,10 @@ class NEQ(Binary, Predicate): pass +class SimilarTo(Binary, Predicate): + pass + + class Sub(Binary): pass @@ -2189,7 +2252,13 @@ class Distinct(Expression): class 
In(Predicate): - arg_types = {"this": True, "expressions": False, "query": False, "unnest": False, "field": False} + arg_types = { + "this": True, + "expressions": False, + "query": False, + "unnest": False, + "field": False, + } class TimeUnit(Expression): @@ -2255,7 +2324,9 @@ class Func(Condition): @classmethod def sql_names(cls): if cls is Func: - raise NotImplementedError("SQL name is only supported by concrete function implementations") + raise NotImplementedError( + "SQL name is only supported by concrete function implementations" + ) if not hasattr(cls, "_sql_names"): cls._sql_names = [camel_to_snake_case(cls.__name__)] return cls._sql_names @@ -2408,8 +2479,8 @@ class DateDiff(Func, TimeUnit): arg_types = {"this": True, "expression": True, "unit": False} -class DateTrunc(Func, TimeUnit): - arg_types = {"this": True, "unit": True, "zone": False} +class DateTrunc(Func): + arg_types = {"this": True, "expression": True, "zone": False} class DatetimeAdd(Func, TimeUnit): @@ -2791,6 +2862,10 @@ class Year(Func): pass +class Use(Expression): + pass + + def _norm_args(expression): args = {} @@ -2822,7 +2897,7 @@ def maybe_parse( dialect=None, prefix=None, **opts, -): +) -> t.Optional[Expression]: """Gracefully handle a possible string or expression. Example: @@ -3073,7 +3148,7 @@ def except_(left, right, distinct=True, dialect=None, **opts): return Except(this=left, expression=right, distinct=distinct) -def select(*expressions, dialect=None, **opts): +def select(*expressions, dialect=None, **opts) -> Select: """ Initializes a syntax tree from one or multiple SELECT expressions. @@ -3095,7 +3170,7 @@ def select(*expressions, dialect=None, **opts): return Select().select(*expressions, dialect=dialect, **opts) -def from_(*expressions, dialect=None, **opts): +def from_(*expressions, dialect=None, **opts) -> Select: """ Initializes a syntax tree from a FROM expression. @@ -3117,7 +3192,7 @@ def from_(*expressions, dialect=None, **opts): return Select().from_(*expressions, dialect=dialect, **opts) -def update(table, properties, where=None, from_=None, dialect=None, **opts): +def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update: """ Creates an update statement. @@ -3139,7 +3214,10 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts): update = Update(this=maybe_parse(table, into=Table, dialect=dialect)) update.set( "expressions", - [EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) for k, v in properties.items()], + [ + EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) + for k, v in properties.items() + ], ) if from_: update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts)) @@ -3150,7 +3228,7 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts): return update -def delete(table, where=None, dialect=None, **opts): +def delete(table, where=None, dialect=None, **opts) -> Delete: """ Builds a delete statement. @@ -3174,7 +3252,7 @@ def delete(table, where=None, dialect=None, **opts): ) -def condition(expression, dialect=None, **opts): +def condition(expression, dialect=None, **opts) -> Condition: """ Initialize a logical condition expression. 
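A short sketch of the `update` helper whose signature is annotated in the hunk above (the table, columns, and values here are illustrative):

    from sqlglot.expressions import update

    stmt = update(
        "users",
        {"active": False, "name": "jo"},  # becomes the SET clause; values run through convert()
        where="id = 1",
    )
    print(stmt.sql())  # roughly: UPDATE users SET active = FALSE, name = 'jo' WHERE id = 1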
@@ -3199,7 +3277,7 @@ def condition(expression, dialect=None, **opts): Returns: Condition: the expression """ - return maybe_parse( + return maybe_parse( # type: ignore expression, into=Condition, dialect=dialect, @@ -3207,7 +3285,7 @@ def condition(expression, dialect=None, **opts): ) -def and_(*expressions, dialect=None, **opts): +def and_(*expressions, dialect=None, **opts) -> And: """ Combine multiple conditions with an AND logical operator. @@ -3227,7 +3305,7 @@ def and_(*expressions, dialect=None, **opts): return _combine(expressions, And, dialect, **opts) -def or_(*expressions, dialect=None, **opts): +def or_(*expressions, dialect=None, **opts) -> Or: """ Combine multiple conditions with an OR logical operator. @@ -3247,7 +3325,7 @@ def or_(*expressions, dialect=None, **opts): return _combine(expressions, Or, dialect, **opts) -def not_(expression, dialect=None, **opts): +def not_(expression, dialect=None, **opts) -> Not: """ Wrap a condition with a NOT operator. @@ -3272,14 +3350,14 @@ def not_(expression, dialect=None, **opts): return Not(this=_wrap_operator(this)) -def paren(expression): +def paren(expression) -> Paren: return Paren(this=expression) SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$") -def to_identifier(alias, quoted=None): +def to_identifier(alias, quoted=None) -> t.Optional[Identifier]: if alias is None: return None if isinstance(alias, Identifier): @@ -3293,16 +3371,16 @@ def to_identifier(alias, quoted=None): return identifier -def to_table(sql_path: str, **kwargs) -> Table: +def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]: """ Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. - If a table is passed in then that table is returned. Args: - sql_path(str|Table): `[catalog].[schema].[table]` string + sql_path: a `[catalog].[schema].[table]` string. + Returns: - Table: A table expression + A table expression. """ if sql_path is None or isinstance(sql_path, Table): return sql_path @@ -3393,7 +3471,7 @@ def subquery(expression, alias=None, dialect=None, **opts): return Select().from_(expression, dialect=dialect, **opts) -def column(col, table=None, quoted=None): +def column(col, table=None, quoted=None) -> Column: """ Build a Column. Args: @@ -3408,7 +3486,7 @@ def column(col, table=None, quoted=None): ) -def table_(table, db=None, catalog=None, quoted=None, alias=None): +def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table: """Build a Table. Args: @@ -3427,7 +3505,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None): ) -def values(values, alias=None): +def values(values, alias=None) -> Values: """Build VALUES statement. Example: @@ -3449,7 +3527,7 @@ def values(values, alias=None): ) -def convert(value): +def convert(value) -> Expression: """Convert a python value into an expression object. Raises an error if a conversion is not possible. 
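The `condition`, `and_`, `or_`, and `not_` helpers annotated above accept both strings and expressions; a minimal sketch (the predicates are invented):

    from sqlglot import and_, not_, or_

    where = and_("x = 1", or_("y > 2", not_("z IS NULL")))
    print(where.sql())  # roughly: x = 1 AND (y > 2 OR NOT z IS NULL)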
@@ -3500,15 +3578,14 @@ def replace_children(expression, fun): for cn in child_nodes: if isinstance(cn, Expression): - cns = ensure_list(fun(cn)) - for child_node in cns: + for child_node in ensure_collection(fun(cn)): new_child_nodes.append(child_node) child_node.parent = expression child_node.arg_key = k else: new_child_nodes.append(cn) - expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0) + expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0) def column_table_names(expression): @@ -3529,7 +3606,7 @@ def column_table_names(expression): return list(dict.fromkeys(column.table for column in expression.find_all(Column))) -def table_name(table): +def table_name(table) -> str: """Get the full name of a table as a string. Args: @@ -3546,6 +3623,9 @@ def table_name(table): table = maybe_parse(table, into=Table) + if not table: + raise ValueError(f"Cannot parse {table}") + return ".".join( part for part in ( diff --git a/sqlglot/generator.py b/sqlglot/generator.py index ca14425..11d9073 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1,4 +1,8 @@ +from __future__ import annotations + import logging +import re +import typing as t from sqlglot import exp from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors @@ -8,6 +12,8 @@ from sqlglot.tokens import TokenType logger = logging.getLogger("sqlglot") +NEWLINE_RE = re.compile("\r\n?|\n") + class Generator: """ @@ -47,8 +53,7 @@ class Generator: The default is on the smaller end because the length only represents a segment and not the true line length. Default: 80 - annotations: Whether or not to show annotations in the SQL when `pretty` is True. - Annotations can only be shown in pretty mode otherwise they may clobber resulting sql. + comments: Whether or not to preserve comments in the ouput SQL code. Default: True """ @@ -65,14 +70,16 @@ class Generator: exp.VolatilityProperty: lambda self, e: self.sql(e.name), } - # whether 'CREATE ... TRANSIENT ... TABLE' is allowed - # can override in dialects + # Whether 'CREATE ... TRANSIENT ... 
TABLE' is allowed CREATE_TRANSIENT = False - # whether or not null ordering is supported in order by + + # Whether or not null ordering is supported in order by NULL_ORDERING_SUPPORTED = True - # always do union distinct or union all + + # Always do union distinct or union all EXPLICIT_UNION = False - # wrap derived values in parens, usually standard but spark doesn't support it + + # Wrap derived values in parens, usually standard but spark doesn't support it WRAP_DERIVED_VALUES = True TYPE_MAPPING = { @@ -80,7 +87,7 @@ class Generator: exp.DataType.Type.NVARCHAR: "VARCHAR", } - TOKEN_MAPPING = {} + TOKEN_MAPPING: t.Dict[TokenType, str] = {} STRUCT_DELIMITER = ("<", ">") @@ -96,6 +103,8 @@ class Generator: exp.TableFormatProperty, } + WITH_SEPARATED_COMMENTS = (exp.Select,) + __slots__ = ( "time_mapping", "time_trie", @@ -122,7 +131,7 @@ class Generator: "_escaped_quote_end", "_leading_comma", "_max_text_width", - "_annotations", + "_comments", ) def __init__( @@ -148,7 +157,7 @@ class Generator: max_unsupported=3, leading_comma=False, max_text_width=80, - annotations=True, + comments=True, ): import sqlglot @@ -177,7 +186,7 @@ class Generator: self._escaped_quote_end = self.escape + self.quote_end self._leading_comma = leading_comma self._max_text_width = max_text_width - self._annotations = annotations + self._comments = comments def generate(self, expression): """ @@ -204,7 +213,6 @@ class Generator: return sql def unsupported(self, message): - if self.unsupported_level == ErrorLevel.IMMEDIATE: raise UnsupportedError(message) self.unsupported_messages.append(message) @@ -215,9 +223,31 @@ class Generator: def seg(self, sql, sep=" "): return f"{self.sep(sep)}{sql}" + def maybe_comment(self, sql, expression, single_line=False): + comment = expression.comment if self._comments else None + + if not comment: + return sql + + comment = " " + comment if comment[0].strip() else comment + comment = comment + " " if comment[-1].strip() else comment + + if isinstance(expression, self.WITH_SEPARATED_COMMENTS): + return f"/*{comment}*/{self.sep()}{sql}" + + if not self.pretty: + return f"{sql} /*{comment}*/" + + if not NEWLINE_RE.search(comment): + return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/" + + return f"/*{comment}*/\n{sql}" + def wrap(self, expression): this_sql = self.indent( - self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"), + self.sql(expression) + if isinstance(expression, (exp.Select, exp.Union)) + else self.sql(expression, "this"), level=1, pad=0, ) @@ -251,7 +281,7 @@ class Generator: for i, line in enumerate(lines) ) - def sql(self, expression, key=None): + def sql(self, expression, key=None, comment=True): if not expression: return "" @@ -264,29 +294,24 @@ class Generator: transform = self.TRANSFORMS.get(expression.__class__) if callable(transform): - return transform(self, expression) - if transform: - return transform - - if not isinstance(expression, exp.Expression): + sql = transform(self, expression) + elif transform: + sql = transform + elif isinstance(expression, exp.Expression): + exp_handler_name = f"{expression.key}_sql" + + if hasattr(self, exp_handler_name): + sql = getattr(self, exp_handler_name)(expression) + elif isinstance(expression, exp.Func): + sql = self.function_fallback_sql(expression) + elif isinstance(expression, exp.Property): + sql = self.property_sql(expression) + else: + raise ValueError(f"Unsupported expression type {expression.__class__.__name__}") + else: raise 
ValueError(f"Expected an Expression. Received {type(expression)}: {expression}") - exp_handler_name = f"{expression.key}_sql" - if hasattr(self, exp_handler_name): - return getattr(self, exp_handler_name)(expression) - - if isinstance(expression, exp.Func): - return self.function_fallback_sql(expression) - - if isinstance(expression, exp.Property): - return self.property_sql(expression) - - raise ValueError(f"Unsupported expression type {expression.__class__.__name__}") - - def annotation_sql(self, expression): - if self._annotations and self.pretty: - return f"{self.sql(expression, 'expression')} # {expression.name}" - return self.sql(expression, "expression") + return self.maybe_comment(sql, expression) if self._comments and comment else sql def uncache_sql(self, expression): table = self.sql(expression, "this") @@ -371,7 +396,9 @@ class Generator: expression_sql = self.sql(expression, "expression") expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else "" temporary = " TEMPORARY" if expression.args.get("temporary") else "" - transient = " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" + transient = ( + " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" + ) replace = " OR REPLACE" if expression.args.get("replace") else "" exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" unique = " UNIQUE" if expression.args.get("unique") else "" @@ -434,7 +461,9 @@ class Generator: def delete_sql(self, expression): this = self.sql(expression, "this") using_sql = ( - f" USING {self.expressions(expression, 'using', sep=', USING ')}" if expression.args.get("using") else "" + f" USING {self.expressions(expression, 'using', sep=', USING ')}" + if expression.args.get("using") + else "" ) where_sql = self.sql(expression, "where") sql = f"DELETE FROM {this}{using_sql}{where_sql}" @@ -481,15 +510,18 @@ class Generator: return f"{this} ON {table} {columns}" def identifier_sql(self, expression): - value = expression.name - value = value.lower() if self.normalize else value + text = expression.name + text = text.lower() if self.normalize else text if expression.args.get("quoted") or self.identify: - return f"{self.identifier_start}{value}{self.identifier_end}" - return value + text = f"{self.identifier_start}{text}{self.identifier_end}" + return text def partition_sql(self, expression): keys = csv( - *[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")] + *[ + f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name + for prop in expression.this + ] ) return f"PARTITION({keys})" @@ -504,9 +536,9 @@ class Generator: elif p_class in self.ROOT_PROPERTIES: root_properties.append(p) - return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties( - exp.Properties(expressions=with_properties) - ) + return self.root_properties( + exp.Properties(expressions=root_properties) + ) + self.with_properties(exp.Properties(expressions=with_properties)) def root_properties(self, properties): if properties.expressions: @@ -551,7 +583,9 @@ class Generator: this = f"{this}{self.sql(expression, 'this')}" exists = " IF EXISTS " if expression.args.get("exists") else " " - partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else "" + partition_sql = ( + self.sql(expression, "partition") if expression.args.get("partition") else "" + ) expression_sql = self.sql(expression, "expression") sep = 
self.sep() if partition_sql else "" sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}" @@ -669,7 +703,9 @@ class Generator: def group_sql(self, expression): group_by = self.op_expressions("GROUP BY", expression) grouping_sets = self.expressions(expression, key="grouping_sets", indent=False) - grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else "" + grouping_sets = ( + f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else "" + ) cube = self.expressions(expression, key="cube", indent=False) cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else "" rollup = self.expressions(expression, key="rollup", indent=False) @@ -711,10 +747,10 @@ class Generator: this_sql = self.sql(expression, "this") return f"{expression_sql}{op_sql} {this_sql}{on_sql}" - def lambda_sql(self, expression): + def lambda_sql(self, expression, arrow_sep="->"): args = self.expressions(expression, flat=True) args = f"({args})" if len(args.split(",")) > 1 else args - return self.no_identify(lambda: f"{args} -> {self.sql(expression, 'this')}") + return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}") def lateral_sql(self, expression): this = self.sql(expression, "this") @@ -748,7 +784,7 @@ class Generator: if self._replace_backslash: text = text.replace("\\", "\\\\") text = text.replace(self.quote_end, self._escaped_quote_end) - return f"{self.quote_start}{text}{self.quote_end}" + text = f"{self.quote_start}{text}{self.quote_end}" return text def loaddata_sql(self, expression): @@ -796,13 +832,21 @@ class Generator: sort_order = " DESC" if desc else "" nulls_sort_change = "" - if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last): + if nulls_first and ( + (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last + ): nulls_sort_change = " NULLS FIRST" - elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last: + elif ( + nulls_last + and ((asc and nulls_are_small) or (desc and nulls_are_large)) + and not nulls_are_last + ): nulls_sort_change = " NULLS LAST" if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED: - self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect") + self.unsupported( + "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect" + ) nulls_sort_change = "" return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}" @@ -835,7 +879,7 @@ class Generator: sql = self.query_modifiers( expression, f"SELECT{hint}{distinct}{expressions}", - self.sql(expression, "from"), + self.sql(expression, "from", comment=False), ) return self.prepend_ctes(expression, sql) @@ -858,6 +902,13 @@ class Generator: def parameter_sql(self, expression): return f"@{self.sql(expression, 'this')}" + def sessionparameter_sql(self, expression): + this = self.sql(expression, "this") + kind = expression.text("kind") + if kind: + kind = f"{kind}." + return f"@@{kind}{this}" + def placeholder_sql(self, expression): return f":{expression.name}" if expression.name else "?" 
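The new `maybe_comment` hook is what re-attaches parser-collected comments during generation; a hedged round-trip sketch, assuming the default `comments=True` (exact comment placement can vary):

    import sqlglot

    sql = "SELECT a /* keep me */ FROM t"
    print(sqlglot.transpile(sql)[0])
    # expected to round-trip as: SELECT a /* keep me */ FROM t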
@@ -931,7 +982,10 @@ class Generator: def window_spec_sql(self, expression): kind = self.sql(expression, "kind") start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ") - end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW" + end = ( + csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") + or "CURRENT ROW" + ) return f"{kind} BETWEEN {start} AND {end}" def withingroup_sql(self, expression): @@ -1020,7 +1074,9 @@ class Generator: return f"UNIQUE ({columns})" def if_sql(self, expression): - return self.case_sql(exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))) + return self.case_sql( + exp.Case(ifs=[expression.copy()], default=expression.args.get("false")) + ) def in_sql(self, expression): query = expression.args.get("query") @@ -1196,6 +1252,12 @@ class Generator: def neq_sql(self, expression): return self.binary(expression, "<>") + def nullsafeeq_sql(self, expression): + return self.binary(expression, "IS NOT DISTINCT FROM") + + def nullsafeneq_sql(self, expression): + return self.binary(expression, "IS DISTINCT FROM") + def or_sql(self, expression): return self.connector_sql(expression, "OR") @@ -1205,6 +1267,9 @@ class Generator: def trycast_sql(self, expression): return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" + def use_sql(self, expression): + return f"USE {self.sql(expression, 'this')}" + def binary(self, expression, op): return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}" @@ -1240,17 +1305,27 @@ class Generator: if flat: return sep.join(self.sql(e) for e in expressions) - sql = (self.sql(e) for e in expressions) - # the only time leading_comma changes the output is if pretty print is enabled - if self._leading_comma and self.pretty: - pad = " " * self.pad - expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql)) - else: - expressions = self.sep(sep).join(sql) + num_sqls = len(expressions) + + # These are calculated once in case we have the leading_comma / pretty option set, correspondingly + pad = " " * self.pad + stripped_sep = sep.strip() - if indent: - return self.indent(expressions, skip_first=False) - return expressions + result_sqls = [] + for i, e in enumerate(expressions): + sql = self.sql(e, comment=False) + comment = self.maybe_comment("", e, single_line=True) + + if self.pretty: + if self._leading_comma: + result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}") + else: + result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}") + else: + result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}") + + result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) + return self.indent(result_sqls, skip_first=False) if indent else result_sqls def op_expressions(self, op, expression, flat=False): expressions_sql = self.expressions(expression, flat=flat) @@ -1264,7 +1339,9 @@ class Generator: def set_operation(self, expression, op): this = self.sql(expression, "this") op = self.seg(op) - return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}") + return self.query_modifiers( + expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}" + ) def token_sql(self, token_type): return self.TOKEN_MAPPING.get(token_type, token_type.name) @@ -1283,3 +1360,6 @@ class Generator: this = self.sql(expression, "this") expressions = self.expressions(expression, 
flat=True) return f"{this}({expressions})" + + def kwarg_sql(self, expression): + return self.binary(expression, "=>") diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 42965d1..379c2e7 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -1,48 +1,125 @@ +from __future__ import annotations + import inspect import logging import re import sys import typing as t +from collections.abc import Collection from contextlib import contextmanager from copy import copy from enum import Enum +if t.TYPE_CHECKING: + from sqlglot.expressions import Expression, Table + + T = t.TypeVar("T") + E = t.TypeVar("E", bound=Expression) + CAMEL_CASE_PATTERN = re.compile("(? t.Optional[T]: + """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds.""" try: - return arr[index] + return seq[index] except IndexError: return None +@t.overload +def ensure_list(value: t.Collection[T]) -> t.List[T]: + ... + + +@t.overload +def ensure_list(value: T) -> t.List[T]: + ... + + def ensure_list(value): + """ + Ensures that a value is a list, otherwise casts or wraps it into one. + + Args: + value: the value of interest. + + Returns: + The value cast as a list if it's a list or a tuple, or else the value wrapped in a list. + """ if value is None: return [] - return value if isinstance(value, (list, tuple, set)) else [value] + elif isinstance(value, (list, tuple)): + return list(value) + + return [value] + + +@t.overload +def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: + ... -def csv(*args, sep=", "): +@t.overload +def ensure_collection(value: T) -> t.Collection[T]: + ... + + +def ensure_collection(value): + """ + Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list. + + Args: + value: the value of interest. + + Returns: + The value if it's a collection, or else the value wrapped in a list. + """ + if value is None: + return [] + return ( + value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value] + ) + + +def csv(*args, sep: str = ", ") -> str: + """ + Formats any number of string arguments as CSV. + + Args: + args: the string arguments to format. + sep: the argument separator. + + Returns: + The arguments formatted as a CSV string. + """ return sep.join(arg for arg in args if arg) -def subclasses(module_name, classes, exclude=()): +def subclasses( + module_name: str, + classes: t.Type | t.Tuple[t.Type, ...], + exclude: t.Type | t.Tuple[t.Type, ...] = (), +) -> t.List[t.Type]: """ - Returns a list of all subclasses for a specified class set, posibly excluding some of them. + Returns all subclasses for a collection of classes, possibly excluding some of them. Args: - module_name (str): The name of the module to search for subclasses in. - classes (type|tuple[type]): Class(es) we want to find the subclasses of. - exclude (type|tuple[type]): Class(es) we want to exclude from the returned list. + module_name: the name of the module to search for subclasses in. + classes: class(es) we want to find the subclasses of. + exclude: class(es) we want to exclude from the returned list. + Returns: - A list of all the target subclasses. + The target subclasses. """ return [ obj @@ -53,7 +130,18 @@ def subclasses(module_name, classes, exclude=()): ] -def apply_index_offset(expressions, offset): +def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]: + """ + Applies an offset to a given integer literal expression. 
+ + Args: + expressions: the expression the offset will be applied to, wrapped in a list. + offset: the offset that will be applied. + + Returns: + The original expression with the offset applied to it, wrapped in a list. If the provided + `expressions` argument contains more than one expressions, it's returned unaffected. + """ if not offset or len(expressions) != 1: return expressions @@ -64,14 +152,28 @@ def apply_index_offset(expressions, offset): logger.warning("Applying array index offset (%s)", offset) expression.args["this"] = str(int(expression.args["this"]) + offset) return [expression] + return expressions -def camel_to_snake_case(name): +def camel_to_snake_case(name: str) -> str: + """Converts `name` from camelCase to snake_case and returns the result.""" return CAMEL_CASE_PATTERN.sub("_", name).upper() -def while_changing(expression, func): +def while_changing( + expression: t.Optional[Expression], func: t.Callable[[t.Optional[Expression]], E] +) -> E: + """ + Applies a transformation to a given expression until a fix point is reached. + + Args: + expression: the expression to be transformed. + func: the transformation to be applied. + + Returns: + The transformed expression. + """ while True: start = hash(expression) expression = func(expression) @@ -80,10 +182,19 @@ def while_changing(expression, func): return expression -def tsort(dag): +def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]: + """ + Sorts a given directed acyclic graph in topological order. + + Args: + dag: the graph to be sorted. + + Returns: + A list that contains all of the graph's nodes in topological order. + """ result = [] - def visit(node, visited): + def visit(node: T, visited: t.Set[T]) -> None: if node in result: return if node in visited: @@ -103,10 +214,8 @@ def tsort(dag): return result -def open_file(file_name): - """ - Open a file that may be compressed as gzip and return in newline mode. - """ +def open_file(file_name: str) -> t.TextIO: + """Open a file that may be compressed as gzip and return it in universal newline mode.""" with open(file_name, "rb") as f: gzipped = f.read(2) == b"\x1f\x8b" @@ -119,14 +228,14 @@ def open_file(file_name): @contextmanager -def csv_reader(table): +def csv_reader(table: Table) -> t.Any: """ - Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]) + Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`. Args: - table (exp.Table): A table expression with an anonymous function READ_CSV in it + table: a `Table` expression with an anonymous function `READ_CSV` in it. - Returns: + Yields: A python csv reader. """ file, *args = table.this.expressions @@ -147,13 +256,16 @@ def csv_reader(table): file.close() -def find_new_name(taken, base): +def find_new_name(taken: t.Sequence[str], base: str) -> str: """ Searches for a new name. Args: - taken (Sequence[str]): set of taken names - base (str): base name to alter + taken: a collection of taken names. + base: base name to alter. + + Returns: + The new, available name. 
""" if base not in taken: return base @@ -163,22 +275,26 @@ def find_new_name(taken, base): while new in taken: i += 1 new = f"{base}_{i}" + return new -def object_to_dict(obj, **kwargs): +def object_to_dict(obj: t.Any, **kwargs) -> t.Dict: + """Returns a dictionary created from an object's attributes.""" return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs} -def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> t.List[t.Optional[str]]: +def split_num_words( + value: str, sep: str, min_num_words: int, fill_from_start: bool = True +) -> t.List[t.Optional[str]]: """ - Perform a split on a value and return N words as a result with None used for words that don't exist. + Perform a split on a value and return N words as a result with `None` used for words that don't exist. Args: - value: The value to be split - sep: The value to use to split on - min_num_words: The minimum number of words that are going to be in the result - fill_from_start: Indicates that if None values should be inserted at the start or end of the list + value: the value to be split. + sep: the value to use to split on. + min_num_words: the minimum number of words that are going to be in the result. + fill_from_start: indicates that if `None` values should be inserted at the start or end of the list. Examples: >>> split_num_words("db.table", ".", 3) @@ -187,6 +303,9 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b ['db', 'table', None] >>> split_num_words("db.table", ".", 1) ['db', 'table'] + + Returns: + The list of words returned by `split`, possibly augmented by a number of `None` values. """ words = value.split(sep) if fill_from_start: @@ -196,7 +315,7 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b def is_iterable(value: t.Any) -> bool: """ - Checks if the value is an iterable but does not include strings and bytes + Checks if the value is an iterable, excluding the types `str` and `bytes`. Examples: >>> is_iterable([1,2]) @@ -205,28 +324,30 @@ def is_iterable(value: t.Any) -> bool: False Args: - value: The value to check if it is an interable + value: the value to check if it is an iterable. - Returns: Bool indicating if it is an iterable + Returns: + A `bool` value indicating if it is an iterable. """ return hasattr(value, "__iter__") and not isinstance(value, (str, bytes)) -def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generator[t.Any, None, None]: +def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any, None, None]: """ - Flattens a list that can contain both iterables and non-iterable elements + Flattens an iterable that can contain both iterable and non-iterable elements. Objects of + type `str` and `bytes` are not regarded as iterables. Examples: - >>> list(flatten([[1, 2], 3])) - [1, 2, 3] + >>> list(flatten([[1, 2], 3, {4}, (5, "bla")])) + [1, 2, 3, 4, 5, 'bla'] >>> list(flatten([1, 2, 3])) [1, 2, 3] Args: - values: The value to be flattened + values: the value to be flattened. - Returns: - Yields non-iterable elements (not including str or byte as iterable) + Yields: + Non-iterable elements in `values`. 
""" for value in values: if is_iterable(value): diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 30055bc..96331e2 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,5 +1,5 @@ from sqlglot import exp -from sqlglot.helper import ensure_list, subclasses +from sqlglot.helper import ensure_collection, ensure_list, subclasses from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import ensure_schema @@ -48,35 +48,65 @@ class TypeAnnotator: exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL), exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN), - exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), + exp.ApproxDistinct: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.BIGINT + ), exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), - exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.CurrentDatetime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATETIME + ), + exp.CurrentTime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), + exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), - exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), + exp.DatetimeAdd: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATETIME + ), + exp.DatetimeSub: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATETIME + ), exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimestampAdd: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), + exp.TimestampSub: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), exp.TimeDiff: 
lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.DateStrToDate: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATE + ), + exp.DateToDateStr: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"), + exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"), + exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), + exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"), + exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.GroupConcat: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), + exp.ArrayConcat: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), @@ -88,32 +118,52 @@ class TypeAnnotator: exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.ApproxQuantile: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DOUBLE + ), + exp.RegexpLike: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.BOOLEAN + ), exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.StrToTime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.StddevSamp: lambda 
self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), + exp.TimeStrToDate: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATE + ), + exp.TimeStrToTime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.UnixToTime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), + exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.VariancePop: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DOUBLE + ), exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), } @@ -124,7 +174,11 @@ class TypeAnnotator: exp.DataType.Type.TEXT: set(), exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT}, exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT}, - exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT}, + exp.DataType.Type.NCHAR: { + exp.DataType.Type.VARCHAR, + exp.DataType.Type.NVARCHAR, + exp.DataType.Type.TEXT, + }, exp.DataType.Type.CHAR: { exp.DataType.Type.NCHAR, exp.DataType.Type.VARCHAR, @@ -135,7 +189,11 @@ class TypeAnnotator: exp.DataType.Type.DOUBLE: set(), exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE}, exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}, - exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}, + exp.DataType.Type.BIGINT: { + exp.DataType.Type.DECIMAL, + exp.DataType.Type.FLOAT, + exp.DataType.Type.DOUBLE, + }, exp.DataType.Type.INT: { exp.DataType.Type.BIGINT, exp.DataType.Type.DECIMAL, @@ -160,7 +218,10 @@ class TypeAnnotator: # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ exp.DataType.Type.TIMESTAMPLTZ: set(), exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ}, - 
exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ}, + exp.DataType.Type.TIMESTAMP: { + exp.DataType.Type.TIMESTAMPTZ, + exp.DataType.Type.TIMESTAMPLTZ, + }, exp.DataType.Type.DATETIME: { exp.DataType.Type.TIMESTAMP, exp.DataType.Type.TIMESTAMPTZ, @@ -219,7 +280,7 @@ class TypeAnnotator: def _annotate_args(self, expression): for value in expression.args.values(): - for v in ensure_list(value): + for v in ensure_collection(value): self._maybe_annotate(v) return expression @@ -243,7 +304,9 @@ class TypeAnnotator: if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: expression.type = exp.DataType.Type.NULL elif exp.DataType.Type.NULL in (left_type, right_type): - expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")) + expression.type = exp.DataType.build( + "NULLABLE", expressions=exp.DataType.build("BOOLEAN") + ) else: expression.type = exp.DataType.Type.BOOLEAN elif isinstance(expression, (exp.Condition, exp.Predicate)): @@ -276,3 +339,17 @@ class TypeAnnotator: def _annotate_with_type(self, expression, target_type): expression.type = target_type return self._annotate_args(expression) + + def _annotate_by_args(self, expression, *args): + self._annotate_args(expression) + expressions = [] + for arg in args: + arg_expr = expression.args.get(arg) + expressions.extend(expr for expr in ensure_list(arg_expr) if expr) + + last_datatype = None + for expr in expressions: + last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type) + + expression.type = last_datatype or exp.DataType.Type.UNKNOWN + return expression diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index 0854336..29621af 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -60,7 +60,9 @@ def _join_is_used(scope, join, alias): on_clause_columns = set(id(column) for column in on.find_all(exp.Column)) else: on_clause_columns = set() - return any(column for column in scope.source_columns(alias) if id(column) not in on_clause_columns) + return any( + column for column in scope.source_columns(alias) if id(column) not in on_clause_columns + ) def _is_joined_on_all_unique_outputs(scope, join): diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index e30c263..8704e90 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -45,7 +45,13 @@ def eliminate_subqueries(expression): # All table names are taken for scope in root.traverse(): - taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)}) + taken.update( + { + source.name: source + for _, source in scope.sources.items() + if isinstance(source, exp.Table) + } + ) # Map of Expression->alias # Existing CTES in the root expression. We'll use this for deduplication. 
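The `_annotate_by_args` rule added above derives a type for CASE/IF/COALESCE by coercing the types of their arguments; a sketch, assuming the module exposes an `annotate_types` entry point that wraps `TypeAnnotator`:

    import sqlglot
    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types

    expression = annotate_types(sqlglot.parse_one("SELECT COALESCE(1, 2.5) FROM t"))
    # 1 annotates as INT and 2.5 as DOUBLE; _maybe_coerce widens INT -> DOUBLE
    print(expression.find(exp.Coalesce).type)  # DataType.Type.DOUBLE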
@@ -70,7 +76,9 @@ def eliminate_subqueries(expression): new_ctes.append(cte_scope.expression.parent) # Now append the rest - for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes): + for scope in itertools.chain( + root.union_scopes, root.subquery_scopes, root.derived_table_scopes + ): for child_scope in scope.traverse(): new_cte = _eliminate(child_scope, existing_ctes, taken) if new_cte: diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 70e4629..9ae4966 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -122,7 +122,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): unmergable_window_columns = [ column for column in outer_scope.columns - if column.find_ancestor(exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc) + if column.find_ancestor( + exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc + ) ] window_expressions_in_unmergable = [ column @@ -147,7 +149,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): and not ( isinstance(from_or_join, exp.From) and inner_select.args.get("where") - and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])) + and any( + j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []) + ) ) and not _is_a_window_expression_in_unmergable_operation() ) @@ -203,7 +207,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias): if table.alias_or_name == node_to_replace.alias_or_name: table.set("this", exp.to_identifier(new_subquery.alias_or_name)) outer_scope.remove_source(alias) - outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]) + outer_scope.add_source( + new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name] + ) def _merge_joins(outer_scope, inner_scope, from_or_join): @@ -296,7 +302,9 @@ def _merge_order(outer_scope, inner_scope): inner_scope (sqlglot.optimizer.scope.Scope) """ if ( - any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]) + any( + outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"] + ) or len(outer_scope.selected_sources) != 1 or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions) ): diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index ab30d7a..db538ef 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -50,7 +50,9 @@ def normalization_distance(expression, dnf=False): Returns: int: difference """ - return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1) + return sum(_predicate_lengths(expression, dnf)) - ( + len(list(expression.find_all(exp.Connector))) + 1 + ) def _predicate_lengths(expression, dnf): diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 0c74e36..40e4ab1 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -68,4 +68,8 @@ def normalize(expression): def other_table_names(join, exclude): - return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude] + return [ + name + for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) + if name != exclude + ] diff --git a/sqlglot/optimizer/optimizer.py 
b/sqlglot/optimizer/optimizer.py index 5ad8f46..b2ed062 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -58,6 +58,8 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar # Find any additional rule parameters, beyond `expression` rule_params = rule.__code__.co_varnames - rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs} + rule_kwargs = { + param: possible_kwargs[param] for param in rule_params if param in possible_kwargs + } expression = rule(expression, **rule_kwargs) return expression diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index 583d059..6364f65 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -45,7 +45,11 @@ def pushdown(condition, sources, scope_ref_count): condition = condition.replace(simplify(condition)) cnf_like = normalized(condition) or not normalized(condition, dnf=True) - predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition]) + predicates = list( + condition.flatten() + if isinstance(condition, exp.And if cnf_like else exp.Or) + else [condition] + ) if cnf_like: pushdown_cnf(predicates, sources, scope_ref_count) @@ -104,11 +108,17 @@ def pushdown_dnf(predicates, scope, scope_ref_count): for column in predicate.find_all(exp.Column): if column.table == table: condition = column.find_ancestor(exp.Condition) - predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition + predicate_condition = ( + exp.and_(predicate_condition, condition) + if predicate_condition + else condition + ) if predicate_condition: conditions[table] = ( - exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition + exp.or_(conditions[table], predicate_condition) + if table in conditions + else predicate_condition ) for name, node in nodes.items(): @@ -146,10 +156,16 @@ def nodes_for_predicate(predicate, sources, scope_ref_count): nodes[table] = node elif isinstance(node, exp.Select) and len(tables) == 1: # We can't push down window expressions - has_window_expression = any(select for select in node.selects if select.find(exp.Window)) + has_window_expression = any( + select for select in node.selects if select.find(exp.Window) + ) # we can't push down predicates to select statements if they are referenced in # multiple places. 
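+            # A minimal restatement of the guard below: pushdown is only
+            # attempted when the select has no GROUP BY, is referenced by at
+            # most one scope, and projects no window expressions.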
- if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression: + if ( + not node.args.get("group") + and scope_ref_count[id(source)] < 2 + and not has_window_expression + ): nodes[table] = node return nodes @@ -165,7 +181,7 @@ def replace_aliases(source, predicate): def _replace_alias(column): if isinstance(column, exp.Column) and column.name in aliases: - return aliases[column.name] + return aliases[column.name].copy() return column return predicate.transform(_replace_alias) diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 5820851..abd9492 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -98,7 +98,9 @@ def _remove_unused_selections(scope, parent_selections): def _remove_indexed_selections(scope, indexes_to_remove): - new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove] + new_selections = [ + selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove + ] if not new_selections: new_selections.append(DEFAULT_SELECTION) scope.expression.set("expressions", new_selections) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index ebee92a..69fe2b8 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -215,13 +215,21 @@ def _qualify_columns(scope, resolver): # Determine whether each reference in the order by clause is to a column or an alias. for ordered in scope.find_all(exp.Ordered): for column in ordered.find_all(exp.Column): - if not column.table and column.parent is not ordered and column.name in resolver.all_columns: + if ( + not column.table + and column.parent is not ordered + and column.name in resolver.all_columns + ): columns_missing_from_scope.append(column) # Determine whether each reference in the having clause is to a column or an alias. 
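+    # e.g. given "SELECT x AS y FROM t HAVING MAX(x) > 0", the unqualified "x"
+    # inside the aggregate names a real source column, so it is collected here
+    # and later qualified as "t.x" instead of being resolved as an output alias.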
for having in scope.find_all(exp.Having): for column in having.find_all(exp.Column): - if not column.table and column.find_ancestor(exp.AggFunc) and column.name in resolver.all_columns: + if ( + not column.table + and column.find_ancestor(exp.AggFunc) + and column.name in resolver.all_columns + ): columns_missing_from_scope.append(column) for column in columns_missing_from_scope: @@ -295,7 +303,9 @@ def _qualify_outputs(scope): """Ensure all output columns are aliased""" new_selections = [] - for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)): + for i, (selection, aliased_column) in enumerate( + itertools.zip_longest(scope.selects, scope.outer_column_list) + ): if isinstance(selection, exp.Column): # convoluted setter because a simple selection.replace(alias) would require a copy alias_ = alias(exp.column(""), alias=selection.name) @@ -343,14 +353,18 @@ class _Resolver: (str) table name """ if self._unambiguous_columns is None: - self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns()) + self._unambiguous_columns = self._get_unambiguous_columns( + self._get_all_source_columns() + ) return self._unambiguous_columns.get(column_name) @property def all_columns(self): """All available columns of all sources in this scope""" if self._all_columns is None: - self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns) + self._all_columns = set( + column for columns in self._get_all_source_columns().values() for column in columns + ) return self._all_columns def get_source_columns(self, name, only_visible=False): @@ -377,7 +391,9 @@ class _Resolver: def _get_all_source_columns(self): if self._source_columns is None: - self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources} + self._source_columns = { + k: self.get_source_columns(k) for k in self.scope.selected_sources + } return self._source_columns def _get_unambiguous_columns(self, source_columns): diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 5a75ee2..18848f3 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -226,7 +226,9 @@ class Scope: self._ensure_collected() columns = self._raw_columns - external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns] + external_columns = [ + column for scope in self.subquery_scopes for column in scope.external_columns + ] named_outputs = {e.alias_or_name for e in self.expression.expressions} @@ -278,7 +280,11 @@ class Scope: Returns: dict[str, Scope]: Mapping of source alias to Scope """ - return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte} + return { + alias: scope + for alias, scope in self.sources.items() + if isinstance(scope, Scope) and scope.is_cte + } @property def selects(self): @@ -307,7 +313,9 @@ class Scope: sources in the current scope. 
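+
+        Example (illustrative): in
+        "SELECT * FROM a WHERE a.x = (SELECT b.y FROM b WHERE b.z = a.x)",
+        the subquery scope over "b" reports "a.x" as an external column,
+        because "a" is not one of that scope's own sources.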
""" if self._external_columns is None: - self._external_columns = [c for c in self.columns if c.table not in self.selected_sources] + self._external_columns = [ + c for c in self.columns if c.table not in self.selected_sources + ] return self._external_columns @property diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index c077906..d759e86 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -229,7 +229,9 @@ def simplify_literals(expression): operands.append(a) if len(operands) < size: - return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands) + return functools.reduce( + lambda a, b: expression.__class__(this=a, expression=b), operands + ) elif isinstance(expression, exp.Neg): this = expression.this if this.is_number: @@ -255,6 +257,12 @@ def _simplify_binary(expression, a, b): return TRUE if not_ else FALSE if a == NULL: return FALSE if not_ else TRUE + elif isinstance(expression, exp.NullSafeEQ): + if a == b: + return TRUE + elif isinstance(expression, exp.NullSafeNEQ): + if a == b: + return FALSE elif NULL in (a, b): return NULL @@ -357,7 +365,7 @@ def extract_date(cast): def extract_interval(interval): try: - from dateutil.relativedelta import relativedelta + from dateutil.relativedelta import relativedelta # type: ignore except ModuleNotFoundError: return None diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 11c6eba..f41a84e 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -89,7 +89,11 @@ def decorrelate(select, parent_select, external_columns, sequence): return if isinstance(predicate, exp.Binary): - key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left + key = ( + predicate.right + if any(node is column for node, *_ in predicate.left.walk()) + else predicate.left + ) else: return @@ -145,7 +149,9 @@ def decorrelate(select, parent_select, external_columns, sequence): else: parent_predicate = _replace(parent_predicate, "TRUE") elif isinstance(parent_predicate, exp.All): - parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})") + parent_predicate = _replace( + parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})" + ) elif isinstance(parent_predicate, exp.Any): if value.this in group_by: parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}") @@ -168,7 +174,9 @@ def decorrelate(select, parent_select, external_columns, sequence): if key in group_by: key.replace(nested) - parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)") + parent_predicate = _replace( + parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)" + ) elif isinstance(predicate, exp.EQ): parent_predicate = _replace( parent_predicate, diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 79a1d90..bbea0e5 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -1,9 +1,13 @@ +from __future__ import annotations + import logging +import typing as t from sqlglot import exp from sqlglot.errors import ErrorLevel, ParseError, concat_errors -from sqlglot.helper import apply_index_offset, ensure_list, list_get +from sqlglot.helper import apply_index_offset, ensure_collection, seq_get from sqlglot.tokens import Token, Tokenizer, TokenType +from sqlglot.trie import in_trie, new_trie logger = logging.getLogger("sqlglot") @@ -20,7 +24,15 @@ def parse_var_map(args): ) -class 
Parser: +class _Parser(type): + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + klass._show_trie = new_trie(key.split(" ") for key in klass.SHOW_PARSERS) + klass._set_trie = new_trie(key.split(" ") for key in klass.SET_PARSERS) + return klass + + +class Parser(metaclass=_Parser): """ Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer` and produces a parsed syntax tree. @@ -45,16 +57,16 @@ class Parser: FUNCTIONS = { **{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()}, "DATE_TO_DATE_STR": lambda args: exp.Cast( - this=list_get(args, 0), + this=seq_get(args, 0), to=exp.DataType(this=exp.DataType.Type.TEXT), ), "TIME_TO_TIME_STR": lambda args: exp.Cast( - this=list_get(args, 0), + this=seq_get(args, 0), to=exp.DataType(this=exp.DataType.Type.TEXT), ), "TS_OR_DS_TO_DATE_STR": lambda args: exp.Substring( this=exp.Cast( - this=list_get(args, 0), + this=seq_get(args, 0), to=exp.DataType(this=exp.DataType.Type.TEXT), ), start=exp.Literal.number(1), @@ -90,6 +102,7 @@ class Parser: TokenType.NVARCHAR, TokenType.TEXT, TokenType.BINARY, + TokenType.VARBINARY, TokenType.JSON, TokenType.INTERVAL, TokenType.TIMESTAMP, @@ -243,6 +256,7 @@ class Parser: EQUALITY = { TokenType.EQ: exp.EQ, TokenType.NEQ: exp.NEQ, + TokenType.NULLSAFE_EQ: exp.NullSafeEQ, } COMPARISON = { @@ -298,6 +312,21 @@ class Parser: TokenType.ANTI, } + LAMBDAS = { + TokenType.ARROW: lambda self, expressions: self.expression( + exp.Lambda, + this=self._parse_conjunction().transform( + self._replace_lambda, {node.name for node in expressions} + ), + expressions=expressions, + ), + TokenType.FARROW: lambda self, expressions: self.expression( + exp.Kwarg, + this=exp.Var(this=expressions[0].name), + expression=self._parse_conjunction(), + ), + } + COLUMN_OPERATORS = { TokenType.DOT: None, TokenType.DCOLON: lambda self, this, to: self.expression( @@ -362,20 +391,30 @@ class Parser: TokenType.DELETE: lambda self: self._parse_delete(), TokenType.CACHE: lambda self: self._parse_cache(), TokenType.UNCACHE: lambda self: self._parse_uncache(), + TokenType.USE: lambda self: self._parse_use(), } PRIMARY_PARSERS = { - TokenType.STRING: lambda _, token: exp.Literal.string(token.text), - TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text), - TokenType.STAR: lambda self, _: exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}), - TokenType.NULL: lambda *_: exp.Null(), - TokenType.TRUE: lambda *_: exp.Boolean(this=True), - TokenType.FALSE: lambda *_: exp.Boolean(this=False), - TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()), - TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text), - TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text), - TokenType.BYTE_STRING: lambda _, token: exp.ByteString(this=token.text), + TokenType.STRING: lambda self, token: self.expression( + exp.Literal, this=token.text, is_string=True + ), + TokenType.NUMBER: lambda self, token: self.expression( + exp.Literal, this=token.text, is_string=False + ), + TokenType.STAR: lambda self, _: self.expression( + exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()} + ), + TokenType.NULL: lambda self, _: self.expression(exp.Null), + TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True), + TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False), + TokenType.PARAMETER: lambda self, _: self.expression( + 
exp.Parameter, this=self._parse_var() or self._parse_primary() + ), + TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text), + TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text), + TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text), TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token), + TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(), } RANGE_PARSERS = { @@ -411,16 +450,24 @@ class Parser: TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty), TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty), - TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(exp.TableFormatProperty), + TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment( + exp.TableFormatProperty + ), TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty), TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty), TokenType.EXECUTE: lambda self: self._parse_execute_as(), TokenType.DETERMINISTIC: lambda self: self.expression( exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE") ), - TokenType.IMMUTABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")), - TokenType.STABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("STABLE")), - TokenType.VOLATILE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")), + TokenType.IMMUTABLE: lambda self: self.expression( + exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE") + ), + TokenType.STABLE: lambda self: self.expression( + exp.VolatilityProperty, this=exp.Literal.string("STABLE") + ), + TokenType.VOLATILE: lambda self: self.expression( + exp.VolatilityProperty, this=exp.Literal.string("VOLATILE") + ), } CONSTRAINT_PARSERS = { @@ -450,7 +497,8 @@ class Parser: "group": lambda self: self._parse_group(), "having": lambda self: self._parse_having(), "qualify": lambda self: self._parse_qualify(), - "window": lambda self: self._match(TokenType.WINDOW) and self._parse_window(self._parse_id_var(), alias=True), + "window": lambda self: self._match(TokenType.WINDOW) + and self._parse_window(self._parse_id_var(), alias=True), "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute), "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort), "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster), @@ -459,6 +507,9 @@ class Parser: "offset": lambda self: self._parse_offset(), } + SHOW_PARSERS: t.Dict[str, t.Callable] = {} + SET_PARSERS: t.Dict[str, t.Callable] = {} + MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) CREATABLES = { @@ -488,7 +539,9 @@ class Parser: "_curr", "_next", "_prev", - "_greedy_subqueries", + "_prev_comment", + "_show_trie", + "_set_trie", ) def __init__( @@ -519,7 +572,7 @@ class Parser: self._curr = None self._next = None self._prev = None - self._greedy_subqueries = False + self._prev_comment = None def parse(self, raw_tokens, sql=None): """ @@ -533,10 +586,12 @@ class Parser: Returns the list of syntax trees (:class:`~sqlglot.expressions.Expression`). 
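+
+        Example (a minimal sketch; the generated SQL may differ by dialect):
+            >>> tokens = Tokenizer().tokenize("SELECT 1")
+            >>> Parser().parse(tokens)[0].sql()
+            'SELECT 1'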
""" - return self._parse(parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql) + return self._parse( + parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql + ) def parse_into(self, expression_types, raw_tokens, sql=None): - for expression_type in ensure_list(expression_types): + for expression_type in ensure_collection(expression_types): parser = self.EXPRESSION_PARSERS.get(expression_type) if not parser: raise TypeError(f"No parser registered for {expression_type}") @@ -597,6 +652,9 @@ class Parser: def expression(self, exp_class, **kwargs): instance = exp_class(**kwargs) + if self._prev_comment: + instance.comment = self._prev_comment + self._prev_comment = None self.validate_expression(instance) return instance @@ -633,14 +691,16 @@ class Parser: return index - def _get_token(self, index): - return list_get(self._tokens, index) - def _advance(self, times=1): self._index += times - self._curr = self._get_token(self._index) - self._next = self._get_token(self._index + 1) - self._prev = self._get_token(self._index - 1) if self._index > 0 else None + self._curr = seq_get(self._tokens, self._index) + self._next = seq_get(self._tokens, self._index + 1) + if self._index > 0: + self._prev = self._tokens[self._index - 1] + self._prev_comment = self._prev.comment + else: + self._prev = None + self._prev_comment = None def _retreat(self, index): self._advance(index - self._index) @@ -661,6 +721,7 @@ class Parser: expression = self._parse_expression() expression = self._parse_set_operations(expression) if expression else self._parse_select() + self._parse_query_modifiers(expression) return expression @@ -682,7 +743,11 @@ class Parser: ) def _parse_exists(self, not_=False): - return self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS) + return ( + self._match(TokenType.IF) + and (not not_ or self._match(TokenType.NOT)) + and self._match(TokenType.EXISTS) + ) def _parse_create(self): replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE) @@ -931,7 +996,9 @@ class Parser: return self.expression( exp.Delete, this=self._parse_table(schema=True), - using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table(schema=True)), + using=self._parse_csv( + lambda: self._match(TokenType.USING) and self._parse_table(schema=True) + ), where=self._parse_where(), ) @@ -983,11 +1050,13 @@ class Parser: return None def parse_values(): - k = self._parse_var() + key = self._parse_var() + value = None + if self._match(TokenType.EQ): - v = self._parse_string() - return (k, v) - return (k, None) + value = self._parse_string() + + return exp.Property(this=key, value=value) self._match_l_paren() values = self._parse_csv(parse_values) @@ -1019,6 +1088,8 @@ class Parser: self.raise_error(f"{this.key} does not support CTE") this = cte elif self._match(TokenType.SELECT): + comment = self._prev_comment + hint = self._parse_hint() all_ = self._match(TokenType.ALL) distinct = self._match(TokenType.DISTINCT) @@ -1033,7 +1104,7 @@ class Parser: self.raise_error("Cannot specify both ALL and DISTINCT after SELECT") limit = self._parse_limit(top=True) - expressions = self._parse_csv(lambda: self._parse_annotation(self._parse_expression())) + expressions = self._parse_csv(self._parse_expression) this = self.expression( exp.Select, @@ -1042,6 +1113,7 @@ class Parser: expressions=expressions, limit=limit, ) + this.comment = comment from_ = self._parse_from() if from_: this.set("from", from_) @@ -1072,8 
+1144,10 @@ class Parser: while True: expressions.append(self._parse_cte()) - if not self._match(TokenType.COMMA): + if not self._match(TokenType.COMMA) and not self._match(TokenType.WITH): break + else: + self._match(TokenType.WITH) return self.expression( exp.With, @@ -1111,11 +1185,7 @@ class Parser: if not alias and not columns: return None - return self.expression( - exp.TableAlias, - this=alias, - columns=columns, - ) + return self.expression(exp.TableAlias, this=alias, columns=columns) def _parse_subquery(self, this): return self.expression( @@ -1150,12 +1220,6 @@ class Parser: if expression: this.set(key, expression) - def _parse_annotation(self, expression): - if self._match(TokenType.ANNOTATION): - return self.expression(exp.Annotation, this=self._prev.text.strip(), expression=expression) - - return expression - def _parse_hint(self): if self._match(TokenType.HINT): hints = self._parse_csv(self._parse_function) @@ -1295,7 +1359,9 @@ class Parser: if not table: self.raise_error("Expected table name") - this = self.expression(exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()) + this = self.expression( + exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots() + ) if schema: return self._parse_schema(this=this) @@ -1500,7 +1566,9 @@ class Parser: if not skip_order_token and not self._match(TokenType.ORDER_BY): return this - return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)) + return self.expression( + exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered) + ) def _parse_sort(self, token_type, exp_class): if not self._match(token_type): @@ -1521,7 +1589,8 @@ class Parser: if ( not explicitly_null_ordered and ( - (asc and self.null_ordering == "nulls_are_small") or (desc and self.null_ordering != "nulls_are_small") + (asc and self.null_ordering == "nulls_are_small") + or (desc and self.null_ordering != "nulls_are_small") ) and self.null_ordering != "nulls_are_last" ): @@ -1606,6 +1675,9 @@ class Parser: def _parse_is(self, this): negate = self._match(TokenType.NOT) + if self._match(TokenType.DISTINCT_FROM): + klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ + return self.expression(klass, this=this, expression=self._parse_expression()) this = self.expression( exp.Is, this=this, @@ -1653,9 +1725,13 @@ class Parser: expression=self._parse_term(), ) elif self._match_pair(TokenType.LT, TokenType.LT): - this = self.expression(exp.BitwiseLeftShift, this=this, expression=self._parse_term()) + this = self.expression( + exp.BitwiseLeftShift, this=this, expression=self._parse_term() + ) elif self._match_pair(TokenType.GT, TokenType.GT): - this = self.expression(exp.BitwiseRightShift, this=this, expression=self._parse_term()) + this = self.expression( + exp.BitwiseRightShift, this=this, expression=self._parse_term() + ) else: break @@ -1685,7 +1761,7 @@ class Parser: ) index = self._index - type_token = self._parse_types() + type_token = self._parse_types(check_func=True) this = self._parse_column() if type_token: @@ -1698,7 +1774,7 @@ class Parser: return this - def _parse_types(self): + def _parse_types(self, check_func=False): index = self._index if not self._match_set(self.TYPE_TOKENS): @@ -1708,10 +1784,13 @@ class Parser: nested = type_token in self.NESTED_TYPE_TOKENS is_struct = type_token == TokenType.STRUCT expressions = None + maybe_func = False if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): return exp.DataType( - this=exp.DataType.Type.ARRAY, 
expressions=[exp.DataType.build(type_token.value)], nested=True + this=exp.DataType.Type.ARRAY, + expressions=[exp.DataType.build(type_token.value)], + nested=True, ) if self._match(TokenType.L_BRACKET): @@ -1731,6 +1810,7 @@ class Parser: return None self._match_r_paren() + maybe_func = True if nested and self._match(TokenType.LT): if is_struct: @@ -1741,25 +1821,46 @@ class Parser: if not self._match(TokenType.GT): self.raise_error("Expecting >") + value = None if type_token in self.TIMESTAMPS: - tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ - if tz: - return exp.DataType( + if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ: + value = exp.DataType( this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions, ) - ltz = self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ - if ltz: - return exp.DataType( + elif ( + self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ + ): + value = exp.DataType( this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions, ) - self._match(TokenType.WITHOUT_TIME_ZONE) + elif self._match(TokenType.WITHOUT_TIME_ZONE): + value = exp.DataType( + this=exp.DataType.Type.TIMESTAMP, + expressions=expressions, + ) - return exp.DataType( - this=exp.DataType.Type.TIMESTAMP, - expressions=expressions, - ) + maybe_func = maybe_func and value is None + + if value is None: + value = exp.DataType( + this=exp.DataType.Type.TIMESTAMP, + expressions=expressions, + ) + + if maybe_func and check_func: + index2 = self._index + peek = self._parse_string() + + if not peek: + self._retreat(index) + return None + + self._retreat(index2) + + if value: + return value return exp.DataType( this=exp.DataType.Type[type_token.value.upper()], @@ -1826,22 +1927,29 @@ class Parser: return exp.Literal.number(f"0.{self._prev.text}") if self._match(TokenType.L_PAREN): + comment = self._prev_comment query = self._parse_select() if query: expressions = [query] else: - expressions = self._parse_csv(lambda: self._parse_alias(self._parse_conjunction(), explicit=True)) + expressions = self._parse_csv( + lambda: self._parse_alias(self._parse_conjunction(), explicit=True) + ) - this = list_get(expressions, 0) + this = seq_get(expressions, 0) self._parse_query_modifiers(this) self._match_r_paren() if isinstance(this, exp.Subqueryable): - return self._parse_set_operations(self._parse_subquery(this)) - if len(expressions) > 1: - return self.expression(exp.Tuple, expressions=expressions) - return self.expression(exp.Paren, this=this) + this = self._parse_set_operations(self._parse_subquery(this)) + elif len(expressions) > 1: + this = self.expression(exp.Tuple, expressions=expressions) + else: + this = self.expression(exp.Paren, this=this) + if comment: + this.comment = comment + return this return None @@ -1894,7 +2002,8 @@ class Parser: self.validate_expression(this, args) else: this = self.expression(exp.Anonymous, this=this, expressions=args) - self._match_r_paren() + + self._match_r_paren(this) return self._parse_window(this) def _parse_user_defined_function(self): @@ -1920,6 +2029,18 @@ class Parser: return self.expression(exp.Identifier, this=token.text) + def _parse_session_parameter(self): + kind = None + this = self._parse_id_var() or self._parse_primary() + if self._match(TokenType.DOT): + kind = this.name + this = self._parse_var() or self._parse_primary() + return self.expression( + exp.SessionParameter, + this=this, + kind=kind, + ) + def _parse_udf_kwarg(self): this = 
self._parse_id_var() kind = self._parse_types() @@ -1938,27 +2059,24 @@ class Parser: else: expressions = [self._parse_id_var()] - if not self._match(TokenType.ARROW): - self._retreat(index) + if self._match_set(self.LAMBDAS): + return self.LAMBDAS[self._prev.token_type](self, expressions) - if self._match(TokenType.DISTINCT): - this = self.expression(exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)) - else: - this = self._parse_conjunction() + self._retreat(index) - if self._match(TokenType.IGNORE_NULLS): - this = self.expression(exp.IgnoreNulls, this=this) - else: - self._match(TokenType.RESPECT_NULLS) + if self._match(TokenType.DISTINCT): + this = self.expression( + exp.Distinct, expressions=self._parse_csv(self._parse_conjunction) + ) + else: + this = self._parse_conjunction() - return self._parse_alias(self._parse_limit(self._parse_order(this))) + if self._match(TokenType.IGNORE_NULLS): + this = self.expression(exp.IgnoreNulls, this=this) + else: + self._match(TokenType.RESPECT_NULLS) - conjunction = self._parse_conjunction().transform(self._replace_lambda, {node.name for node in expressions}) - return self.expression( - exp.Lambda, - this=conjunction, - expressions=expressions, - ) + return self._parse_alias(self._parse_limit(self._parse_order(this))) def _parse_schema(self, this=None): index = self._index @@ -1966,7 +2084,9 @@ class Parser: self._retreat(index) return this - args = self._parse_csv(lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True))) + args = self._parse_csv( + lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True)) + ) self._match_r_paren() return self.expression(exp.Schema, this=this, expressions=args) @@ -2104,6 +2224,7 @@ class Parser: if not self._match(TokenType.R_BRACKET): self.raise_error("Expected ]") + this.comment = self._prev_comment return self._parse_bracket(this) def _parse_case(self): @@ -2124,7 +2245,9 @@ class Parser: if not self._match(TokenType.END): self.raise_error("Expected END after CASE", self._prev) - return self._parse_window(self.expression(exp.Case, this=expression, ifs=ifs, default=default)) + return self._parse_window( + self.expression(exp.Case, this=expression, ifs=ifs, default=default) + ) def _parse_if(self): if self._match(TokenType.L_PAREN): @@ -2331,7 +2454,9 @@ class Parser: self._match(TokenType.BETWEEN) return { - "value": (self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text) + "value": ( + self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text + ) or self._parse_bitwise(), "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text, } @@ -2348,7 +2473,7 @@ class Parser: this=this, expressions=self._parse_csv(lambda: self._parse_id_var(any_token)), ) - self._match_r_paren() + self._match_r_paren(aliases) return aliases alias = self._parse_id_var(any_token) @@ -2365,28 +2490,29 @@ class Parser: return identifier if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: - return self._advance() or exp.Identifier(this=self._prev.text, quoted=False) - - return self._match_set(tokens or self.ID_VAR_TOKENS) and exp.Identifier(this=self._prev.text, quoted=False) + self._advance() + elif not self._match_set(tokens or self.ID_VAR_TOKENS): + return None + return exp.Identifier(this=self._prev.text, quoted=False) def _parse_string(self): if self._match(TokenType.STRING): - return exp.Literal.string(self._prev.text) + return 
self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev) return self._parse_placeholder() def _parse_number(self): if self._match(TokenType.NUMBER): - return exp.Literal.number(self._prev.text) + return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev) return self._parse_placeholder() def _parse_identifier(self): if self._match(TokenType.IDENTIFIER): - return exp.Identifier(this=self._prev.text, quoted=True) + return self.expression(exp.Identifier, this=self._prev.text, quoted=True) return self._parse_placeholder() def _parse_var(self): if self._match(TokenType.VAR): - return exp.Var(this=self._prev.text) + return self.expression(exp.Var, this=self._prev.text) return self._parse_placeholder() def _parse_var_or_string(self): @@ -2394,27 +2520,27 @@ class Parser: def _parse_null(self): if self._match(TokenType.NULL): - return exp.Null() + return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev) return None def _parse_boolean(self): if self._match(TokenType.TRUE): - return exp.Boolean(this=True) + return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev) if self._match(TokenType.FALSE): - return exp.Boolean(this=False) + return self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev) return None def _parse_star(self): if self._match(TokenType.STAR): - return exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}) + return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev) return None def _parse_placeholder(self): if self._match(TokenType.PLACEHOLDER): - return exp.Placeholder() + return self.expression(exp.Placeholder) elif self._match(TokenType.COLON): self._advance() - return exp.Placeholder(this=self._prev.text) + return self.expression(exp.Placeholder, this=self._prev.text) return None def _parse_except(self): @@ -2432,22 +2558,27 @@ class Parser: self._match_r_paren() return columns - def _parse_csv(self, parse): - parse_result = parse() + def _parse_csv(self, parse_method): + parse_result = parse_method() items = [parse_result] if parse_result is not None else [] while self._match(TokenType.COMMA): - parse_result = parse() + if parse_result and self._prev_comment is not None: + parse_result.comment = self._prev_comment + + parse_result = parse_method() if parse_result is not None: items.append(parse_result) return items - def _parse_tokens(self, parse, expressions): - this = parse() + def _parse_tokens(self, parse_method, expressions): + this = parse_method() while self._match_set(expressions): - this = self.expression(expressions[self._prev.token_type], this=this, expression=parse()) + this = self.expression( + expressions[self._prev.token_type], this=this, expression=parse_method() + ) return this @@ -2460,6 +2591,47 @@ class Parser: def _parse_select_or_expression(self): return self._parse_select() or self._parse_expression() + def _parse_use(self): + return self.expression(exp.Use, this=self._parse_id_var()) + + def _parse_show(self): + parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) + if parser: + return parser(self) + self._advance() + return self.expression(exp.Show, this=self._prev.text.upper()) + + def _default_parse_set_item(self): + return self.expression( + exp.SetItem, + this=self._parse_statement(), + ) + + def _parse_set_item(self): + parser = self._find_parser(self.SET_PARSERS, self._set_trie) + return parser(self) if parser else self._default_parse_set_item() + + def _parse_set(self): + return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item)) + + def _find_parser(self, parsers, trie): + 
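+        # Greedily consume tokens, feeding their uppercased text into the trie:
+        # in_trie returns 0 (no match: stop), 1 (prefix of a longer key: keep
+        # advancing) or 2 (full key matched: dispatch to its registered parser).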
index = self._index + this = [] + while True: + # The current token might be multiple words + curr = self._curr.text.upper() + key = curr.split(" ") + this.append(curr) + self._advance() + result, trie = in_trie(trie, key) + if result == 0: + break + if result == 2: + subparser = parsers[" ".join(this)] + return subparser + self._retreat(index) + return None + def _match(self, token_type): if not self._curr: return None @@ -2491,13 +2663,17 @@ class Parser: return None - def _match_l_paren(self): + def _match_l_paren(self, expression=None): if not self._match(TokenType.L_PAREN): self.raise_error("Expecting (") + if expression and self._prev_comment: + expression.comment = self._prev_comment - def _match_r_paren(self): + def _match_r_paren(self, expression=None): if not self._match(TokenType.R_PAREN): self.raise_error("Expecting )") + if expression and self._prev_comment: + expression.comment = self._prev_comment def _match_text(self, *texts): index = self._index diff --git a/sqlglot/planner.py b/sqlglot/planner.py index ea995d8..cd1de5e 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -72,7 +72,9 @@ class Step: if from_: from_ = from_.expressions if len(from_) > 1: - raise UnsupportedError("Multi-from statements are unsupported. Run it through the optimizer") + raise UnsupportedError( + "Multi-from statements are unsupported. Run it through the optimizer" + ) step = Scan.from_expression(from_[0], ctes) else: @@ -102,7 +104,7 @@ class Step: continue if operand not in operands: operands[operand] = f"_a_{next(sequence)}" - operand.replace(exp.column(operands[operand], step.name, quoted=True)) + operand.replace(exp.column(operands[operand], quoted=True)) else: projections.append(e) @@ -117,9 +119,11 @@ class Step: aggregate = Aggregate() aggregate.source = step.name aggregate.name = step.name - aggregate.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items()) + aggregate.operands = tuple( + alias(operand, alias_) for operand, alias_ in operands.items() + ) aggregate.aggregations = aggregations - aggregate.group = [exp.column(e.alias_or_name, step.name, quoted=True) for e in group.expressions] + aggregate.group = group.expressions aggregate.add_dependency(step) step = aggregate @@ -136,9 +140,6 @@ class Step: sort.key = order.expressions sort.add_dependency(step) step = sort - for k in sort.key + projections: - for column in k.find_all(exp.Column): - column.set("table", exp.to_identifier(step.name, quoted=True)) step.projections = projections @@ -203,7 +204,9 @@ class Scan(Step): alias_ = expression.alias if not alias_: - raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer") + raise UnsupportedError( + "Tables/Subqueries must be aliased. 
Run it through the optimizer" + ) if isinstance(expression, exp.Subquery): table = expression.this diff --git a/sqlglot/py.typed b/sqlglot/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/sqlglot/schema.py b/sqlglot/schema.py index c916330..fcf7291 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -1,44 +1,60 @@ +from __future__ import annotations + import abc +import typing as t from sqlglot import expressions as exp -from sqlglot.errors import OptimizeError +from sqlglot.errors import SchemaError from sqlglot.helper import csv_reader +from sqlglot.trie import in_trie, new_trie + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql.types import StructType + + ColumnMapping = t.Union[t.Dict, str, StructType, t.List] + +TABLE_ARGS = ("this", "db", "catalog") class Schema(abc.ABC): """Abstract base class for database schemas""" @abc.abstractmethod - def add_table(self, table, column_mapping=None): + def add_table( + self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None + ) -> None: """ - Register or update a table. Some implementing classes may require column information to also be provided + Register or update a table. Some implementing classes may require column information to also be provided. Args: - table (sqlglot.expressions.Table|str): Table expression instance or string representing the table - column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table + table: table expression instance or string representing the table. + column_mapping: a column mapping that describes the structure of the table. """ @abc.abstractmethod - def column_names(self, table, only_visible=False): + def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]: """ Get the column names for a table. + Args: - table (sqlglot.expressions.Table): Table expression instance - only_visible (bool): Whether to include invisible columns + table: the `Table` expression instance. + only_visible: whether to include invisible columns. + Returns: - list[str]: list of column names + The list of column names. """ @abc.abstractmethod - def get_column_type(self, table, column): + def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType.Type: """ - Get the exp.DataType type of a column in the schema. + Get the :class:`sqlglot.exp.DataType` type of a column in the schema. Args: - table (sqlglot.expressions.Table): The source table. - column (sqlglot.expressions.Column): The target column. + table: the source table. + column: the target column. + Returns: - sqlglot.expressions.DataType.Type: The resulting column type. + The resulting column type. """ @@ -60,132 +76,179 @@ class MappingSchema(Schema): dialect (str): The dialect to be used for custom type mappings. """ - def __init__(self, schema=None, visible=None, dialect=None): + def __init__( + self, + schema: t.Optional[t.Dict] = None, + visible: t.Optional[t.Dict] = None, + dialect: t.Optional[str] = None, + ) -> None: self.schema = schema or {} - self.visible = visible + self.visible = visible or {} + self.schema_trie = self._build_trie(self.schema) self.dialect = dialect - self._type_mapping_cache = {} - self.supported_table_args = [] - self.forbidden_table_args = set() - if self.schema: - self._initialize_supported_args() + self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {} + self._supported_table_args: t.Tuple[str, ...] 
= tuple() @classmethod - def from_mapping_schema(cls, mapping_schema): + def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: + return MappingSchema( + schema=mapping_schema.schema, + visible=mapping_schema.visible, + dialect=mapping_schema.dialect, + ) + + def copy(self, **kwargs) -> MappingSchema: return MappingSchema( - schema=mapping_schema.schema, visible=mapping_schema.visible, dialect=mapping_schema.dialect + **{ # type: ignore + "schema": self.schema.copy(), + "visible": self.visible.copy(), + "dialect": self.dialect, + **kwargs, + } ) - def copy(self, **kwargs): - return MappingSchema(**{"schema": self.schema.copy(), **kwargs}) + @property + def supported_table_args(self): + if not self._supported_table_args and self.schema: + depth = _dict_depth(self.schema) - def add_table(self, table, column_mapping=None): + if not depth or depth == 1: # {} + self._supported_table_args = tuple() + elif 2 <= depth <= 4: + self._supported_table_args = TABLE_ARGS[: depth - 1] + else: + raise SchemaError(f"Invalid schema shape. Depth: {depth}") + + return self._supported_table_args + + def add_table( + self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None + ) -> None: """ Register or update a table. Updates are only performed if a new column mapping is provided. Args: - table (sqlglot.expressions.Table|str): Table expression instance or string representing the table - column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table + table: the `Table` expression instance or string representing the table. + column_mapping: a column mapping that describes the structure of the table. """ - table = exp.to_table(table) - self._validate_table(table) + table_ = self._ensure_table(table) column_mapping = ensure_column_mapping(column_mapping) - table_args = [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)] - existing_column_mapping = _nested_get( - self.schema, *zip(self.supported_table_args, table_args), raise_on_missing=False - ) - if existing_column_mapping and not column_mapping: + schema = self.find_schema(table_, raise_on_missing=False) + + if schema and not column_mapping: return + _nested_set( self.schema, - [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)], + list(reversed(self.table_parts(table_))), column_mapping, ) - self._initialize_supported_args() + self.schema_trie = self._build_trie(self.schema) - def _get_table_args_from_table(self, table): - if table.args.get("catalog") is not None: - return "catalog", "db", "this" - if table.args.get("db") is not None: - return "db", "this" - return ("this",) + def _ensure_table(self, table: exp.Table | str) -> exp.Table: + table_ = exp.to_table(table) - def _validate_table(self, table): - if not self.supported_table_args and isinstance(table, exp.Table): - return - for forbidden in self.forbidden_table_args: - if table.text(forbidden): - raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") - for expected in self.supported_table_args: - if not table.text(expected): - raise ValueError(f"Table is expected to have {expected}. 
Received: {table.sql()} ") + if not table_: + raise SchemaError(f"Not a valid table '{table}'") + + return table_ + + def table_parts(self, table: exp.Table) -> t.List[str]: + return [table.text(part) for part in TABLE_ARGS if table.text(part)] + + def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]: + table_ = self._ensure_table(table) - def column_names(self, table, only_visible=False): - table = exp.to_table(table) - if not isinstance(table.this, exp.Identifier): - return fs_get(table) + if not isinstance(table_.this, exp.Identifier): + return fs_get(table) # type: ignore - args = tuple(table.text(p) for p in self.supported_table_args) + schema = self.find_schema(table_) - for forbidden in self.forbidden_table_args: - if table.text(forbidden): - raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") + if schema is None: + raise SchemaError(f"Could not find table schema {table}") - columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args))) if not only_visible or not self.visible: - return columns + return list(schema) - visible = _nested_get(self.visible, *zip(self.supported_table_args, args)) - return [col for col in columns if col in visible] + visible = self._nested_get(self.table_parts(table_), self.visible) + return [col for col in schema if col in visible] # type: ignore - def get_column_type(self, table, column): - try: - schema_type = self.schema.get(table.name, {}).get(column.name).upper() + def find_schema( + self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True + ) -> t.Optional[t.Dict[str, str]]: + parts = self.table_parts(table)[0 : len(self.supported_table_args)] + value, trie = in_trie(self.schema_trie if trie is None else trie, parts) + + if value == 0: + if raise_on_missing: + raise SchemaError(f"Cannot find schema for {table}.") + else: + return None + elif value == 1: + possibilities = flatten_schema(trie) + if len(possibilities) == 1: + parts.extend(possibilities[0]) + else: + message = ", ".join(".".join(parts) for parts in possibilities) + if raise_on_missing: + raise SchemaError(f"Ambiguous schema for {table}: {message}.") + return None + + return self._nested_get(parts, raise_on_missing=raise_on_missing) + + def get_column_type( + self, table: exp.Table | str, column: exp.Column | str + ) -> exp.DataType.Type: + column_name = column if isinstance(column, str) else column.name + table_ = exp.to_table(table) + if table_: + table_schema = self.find_schema(table_) + schema_type = table_schema.get(column_name).upper() # type: ignore return self._convert_type(schema_type) - except: - raise OptimizeError(f"Failed to get type for column {column.sql()}") + raise SchemaError(f"Could not convert table '{table}'") - def _convert_type(self, schema_type): + def _convert_type(self, schema_type: str) -> exp.DataType.Type: """ - Convert a type represented as a string to the corresponding exp.DataType.Type object. + Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object. + Args: - schema_type (str): The type we want to convert. + schema_type: the type we want to convert. + Returns: - sqlglot.expressions.DataType.Type: The resulting expression type. + The resulting expression type. 
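+
+        Example (illustrative):
+            >>> MappingSchema()._convert_type("BIGINT") == exp.DataType.Type.BIGINT
+            True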
""" if schema_type not in self._type_mapping_cache: try: - self._type_mapping_cache[schema_type] = exp.maybe_parse( - schema_type, into=exp.DataType, dialect=self.dialect - ).this + expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect) + if expression is None: + raise ValueError(f"Could not parse {schema_type}") + self._type_mapping_cache[schema_type] = expression.this except AttributeError: - raise OptimizeError(f"Failed to convert type {schema_type}") + raise SchemaError(f"Failed to convert type {schema_type}") return self._type_mapping_cache[schema_type] - def _initialize_supported_args(self): - if not self.supported_table_args: - depth = _dict_depth(self.schema) - - all_args = ["this", "db", "catalog"] - if not depth or depth == 1: # {} - self.supported_table_args = [] - elif 2 <= depth <= 4: - self.supported_table_args = tuple(reversed(all_args[: depth - 1])) - else: - raise OptimizeError(f"Invalid schema shape. Depth: {depth}") + def _build_trie(self, schema: t.Dict): + return new_trie(tuple(reversed(t)) for t in flatten_schema(schema)) - self.forbidden_table_args = {"catalog", "db", "this"} - set(self.supported_table_args) + def _nested_get( + self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True + ) -> t.Optional[t.Any]: + return _nested_get( + d or self.schema, + *zip(self.supported_table_args, reversed(parts)), + raise_on_missing=raise_on_missing, + ) -def ensure_schema(schema): +def ensure_schema(schema: t.Any) -> Schema: if isinstance(schema, Schema): return schema return MappingSchema(schema) -def ensure_column_mapping(mapping): +def ensure_column_mapping(mapping: t.Optional[ColumnMapping]): if isinstance(mapping, dict): return mapping elif isinstance(mapping, str): @@ -196,7 +259,7 @@ def ensure_column_mapping(mapping): } # Check if mapping looks like a DataFrame StructType elif hasattr(mapping, "simpleString"): - return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} + return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} # type: ignore elif isinstance(mapping, list): return {x.strip(): None for x in mapping} elif mapping is None: @@ -204,7 +267,20 @@ def ensure_column_mapping(mapping): raise ValueError(f"Invalid mapping provided: {type(mapping)}") -def fs_get(table): +def flatten_schema(schema: t.Dict, keys: t.Optional[t.List[str]] = None) -> t.List[t.List[str]]: + tables = [] + keys = keys or [] + depth = _dict_depth(schema) + + for k, v in schema.items(): + if depth >= 3: + tables.extend(flatten_schema(v, keys + [k])) + elif depth == 2: + tables.append(keys + [k]) + return tables + + +def fs_get(table: exp.Table) -> t.List[str]: name = table.this.name if name.upper() == "READ_CSV": @@ -214,21 +290,23 @@ def fs_get(table): raise ValueError(f"Cannot read schema for {table}") -def _nested_get(d, *path, raise_on_missing=True): +def _nested_get( + d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True +) -> t.Optional[t.Any]: """ Get a value for a nested dictionary. Args: - d (dict): dictionary - *path (tuple[str, str]): tuples of (name, key) + d: the dictionary to search. + *path: tuples of (name, key), where: `key` is the key in the dictionary to get. `name` is a string to use in the error if `key` isn't found. Returns: - The value or None if it doesn't exist + The value or None if it doesn't exist. 
""" for name, key in path: - d = d.get(key) + d = d.get(key) # type: ignore if d is None: if raise_on_missing: name = "table" if name == "this" else name @@ -237,36 +315,44 @@ def _nested_get(d, *path, raise_on_missing=True): return d -def _nested_set(d, keys, value): +def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict: """ In-place set a value for a nested dictionary - Ex: + Example: >>> _nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}} + >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} - d (dict): dictionary - keys (Iterable[str]): ordered iterable of keys that makeup path to value - value (Any): The value to set in the dictionary for the given key path + Args: + d: dictionary to update. + keys: the keys that makeup the path to `value`. + value: the value to set in the dictionary for the given key path. + + Returns: + The (possibly) updated dictionary. """ if not keys: - return + return d + if len(keys) == 1: d[keys[0]] = value - return + return d + subd = d for key in keys[:-1]: if key not in subd: subd = subd.setdefault(key, {}) else: subd = subd[key] + subd[keys[-1]] = value return d -def _dict_depth(d): +def _dict_depth(d: t.Dict) -> int: """ Get the nesting depth of a dictionary. diff --git a/sqlglot/time.py b/sqlglot/time.py index 729b50d..97726b3 100644 --- a/sqlglot/time.py +++ b/sqlglot/time.py @@ -1,9 +1,13 @@ -# the generic time format is based on python time.strftime +import typing as t + +# The generic time format is based on python time.strftime. # https://docs.python.org/3/library/time.html#time.strftime from sqlglot.trie import in_trie, new_trie -def format_time(string, mapping, trie=None): +def format_time( + string: str, mapping: t.Dict[str, str], trie: t.Optional[t.Dict] = None +) -> t.Optional[str]: """ Converts a time string given a mapping. @@ -11,11 +15,16 @@ def format_time(string, mapping, trie=None): >>> format_time("%Y", {"%Y": "YYYY"}) 'YYYY' - mapping: Dictionary of time format to target time format - trie: Optional trie, can be passed in for performance + Args: + mapping: dictionary of time format to target time format. + trie: optional trie, can be passed in for performance. + + Returns: + The converted time string. 
""" if not string: return None + start = 0 end = 1 size = len(string) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 766c01a..95d84d6 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import typing as t from enum import auto from sqlglot.helper import AutoName @@ -27,6 +30,7 @@ class TokenType(AutoName): NOT = auto() EQ = auto() NEQ = auto() + NULLSAFE_EQ = auto() AND = auto() OR = auto() AMP = auto() @@ -36,12 +40,14 @@ class TokenType(AutoName): TILDA = auto() ARROW = auto() DARROW = auto() + FARROW = auto() + HASH = auto() HASH_ARROW = auto() DHASH_ARROW = auto() LR_ARROW = auto() - ANNOTATION = auto() DOLLAR = auto() PARAMETER = auto() + SESSION_PARAMETER = auto() SPACE = auto() BREAK = auto() @@ -73,7 +79,7 @@ class TokenType(AutoName): NVARCHAR = auto() TEXT = auto() BINARY = auto() - BYTEA = auto() + VARBINARY = auto() JSON = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() @@ -142,6 +148,7 @@ class TokenType(AutoName): DESCRIBE = auto() DETERMINISTIC = auto() DISTINCT = auto() + DISTINCT_FROM = auto() DISTRIBUTE_BY = auto() DIV = auto() DROP = auto() @@ -238,6 +245,7 @@ class TokenType(AutoName): RETURNS = auto() RIGHT = auto() RLIKE = auto() + ROLLBACK = auto() ROLLUP = auto() ROW = auto() ROWS = auto() @@ -287,37 +295,49 @@ class TokenType(AutoName): class Token: - __slots__ = ("token_type", "text", "line", "col") + __slots__ = ("token_type", "text", "line", "col", "comment") @classmethod - def number(cls, number): + def number(cls, number: int) -> Token: + """Returns a NUMBER token with `number` as its text.""" return cls(TokenType.NUMBER, str(number)) @classmethod - def string(cls, string): + def string(cls, string: str) -> Token: + """Returns a STRING token with `string` as its text.""" return cls(TokenType.STRING, string) @classmethod - def identifier(cls, identifier): + def identifier(cls, identifier: str) -> Token: + """Returns an IDENTIFIER token with `identifier` as its text.""" return cls(TokenType.IDENTIFIER, identifier) @classmethod - def var(cls, var): + def var(cls, var: str) -> Token: + """Returns an VAR token with `var` as its text.""" return cls(TokenType.VAR, var) - def __init__(self, token_type, text, line=1, col=1): + def __init__( + self, + token_type: TokenType, + text: str, + line: int = 1, + col: int = 1, + comment: t.Optional[str] = None, + ) -> None: self.token_type = token_type self.text = text self.line = line self.col = max(col - len(text), 1) + self.comment = comment - def __repr__(self): + def __repr__(self) -> str: attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__) return f"" class _Tokenizer(type): - def __new__(cls, clsname, bases, attrs): + def __new__(cls, clsname, bases, attrs): # type: ignore klass = super().__new__(cls, clsname, bases, attrs) klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES) @@ -325,27 +345,29 @@ class _Tokenizer(type): klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS) klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS) klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS) + klass._ESCAPES = set(klass.ESCAPES) klass._COMMENTS = dict( - (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS + (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) + for comment in klass.COMMENTS ) klass.KEYWORD_TRIE = new_trie( key.upper() - for key, value in { + for key in { **klass.KEYWORDS, **{comment: TokenType.COMMENT for 
comment in klass._COMMENTS}, **{quote: TokenType.QUOTE for quote in klass._QUOTES}, **{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS}, **{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS}, **{byte_string: TokenType.BYTE_STRING for byte_string in klass._BYTE_STRINGS}, - }.items() + } if " " in key or any(single in key for single in klass.SINGLE_TOKENS) ) return klass @staticmethod - def _delimeter_list_to_dict(list): + def _delimeter_list_to_dict(list: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]: return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list) @@ -375,26 +397,26 @@ class Tokenizer(metaclass=_Tokenizer): "*": TokenType.STAR, "~": TokenType.TILDA, "?": TokenType.PLACEHOLDER, - "#": TokenType.ANNOTATION, "@": TokenType.PARAMETER, # used for breaking a var like x'y' but nothing else # the token type doesn't matter "'": TokenType.QUOTE, "`": TokenType.IDENTIFIER, '"': TokenType.IDENTIFIER, + "#": TokenType.HASH, } - QUOTES = ["'"] + QUOTES: t.List[t.Tuple[str, str] | str] = ["'"] - BIT_STRINGS = [] + BIT_STRINGS: t.List[str | t.Tuple[str, str]] = [] - HEX_STRINGS = [] + HEX_STRINGS: t.List[str | t.Tuple[str, str]] = [] - BYTE_STRINGS = [] + BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = [] - IDENTIFIERS = ['"'] + IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"'] - ESCAPE = "'" + ESCAPES = ["'"] KEYWORDS = { "/*+": TokenType.HINT, @@ -406,8 +428,10 @@ class Tokenizer(metaclass=_Tokenizer): "<=": TokenType.LTE, "<>": TokenType.NEQ, "!=": TokenType.NEQ, + "<=>": TokenType.NULLSAFE_EQ, "->": TokenType.ARROW, "->>": TokenType.DARROW, + "=>": TokenType.FARROW, "#>": TokenType.HASH_ARROW, "#>>": TokenType.DHASH_ARROW, "<->": TokenType.LR_ARROW, @@ -454,6 +478,7 @@ class Tokenizer(metaclass=_Tokenizer): "DESCRIBE": TokenType.DESCRIBE, "DETERMINISTIC": TokenType.DETERMINISTIC, "DISTINCT": TokenType.DISTINCT, + "DISTINCT FROM": TokenType.DISTINCT_FROM, "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY, "DIV": TokenType.DIV, "DROP": TokenType.DROP, @@ -543,6 +568,7 @@ class Tokenizer(metaclass=_Tokenizer): "RETURNS": TokenType.RETURNS, "RIGHT": TokenType.RIGHT, "RLIKE": TokenType.RLIKE, + "ROLLBACK": TokenType.ROLLBACK, "ROLLUP": TokenType.ROLLUP, "ROW": TokenType.ROW, "ROWS": TokenType.ROWS, @@ -622,8 +648,9 @@ class Tokenizer(metaclass=_Tokenizer): "TEXT": TokenType.TEXT, "CLOB": TokenType.TEXT, "BINARY": TokenType.BINARY, - "BLOB": TokenType.BINARY, - "BYTEA": TokenType.BINARY, + "BLOB": TokenType.VARBINARY, + "BYTEA": TokenType.VARBINARY, + "VARBINARY": TokenType.VARBINARY, "TIMESTAMP": TokenType.TIMESTAMP, "TIMESTAMPTZ": TokenType.TIMESTAMPTZ, "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ, @@ -655,13 +682,13 @@ class Tokenizer(metaclass=_Tokenizer): TokenType.SET, TokenType.SHOW, TokenType.TRUNCATE, - TokenType.USE, TokenType.VACUUM, + TokenType.ROLLBACK, } # handle numeric literals like in hive (3L = BIGINT) - NUMERIC_LITERALS = {} - ENCODE = None + NUMERIC_LITERALS: t.Dict[str, str] = {} + ENCODE: t.Optional[str] = None COMMENTS = ["--", ("/*", "*/")] KEYWORD_TRIE = None # autofilled @@ -674,33 +701,39 @@ class Tokenizer(metaclass=_Tokenizer): "_current", "_line", "_col", + "_comment", "_char", "_end", "_peek", + "_prev_token_line", + "_prev_token_comment", "_prev_token_type", + "_replace_backslash", ) - def __init__(self): - """ - Tokenizer consumes a sql string and produces an array of :class:`~sqlglot.tokens.Token` - """ + def __init__(self) -> None: + self._replace_backslash = "\\" in self._ESCAPES # type: 
ignore self.reset() - def reset(self): + def reset(self) -> None: self.sql = "" self.size = 0 - self.tokens = [] + self.tokens: t.List[Token] = [] self._start = 0 self._current = 0 self._line = 1 self._col = 1 + self._comment = None self._char = None self._end = None self._peek = None + self._prev_token_line = -1 + self._prev_token_comment = None self._prev_token_type = None - def tokenize(self, sql): + def tokenize(self, sql: str) -> t.List[Token]: + """Returns a list of tokens corresponding to the SQL string `sql`.""" self.reset() self.sql = sql self.size = len(sql) @@ -712,14 +745,14 @@ class Tokenizer(metaclass=_Tokenizer): if not self._char: break - white_space = self.WHITE_SPACE.get(self._char) - identifier_end = self._IDENTIFIERS.get(self._char) + white_space = self.WHITE_SPACE.get(self._char) # type: ignore + identifier_end = self._IDENTIFIERS.get(self._char) # type: ignore if white_space: if white_space == TokenType.BREAK: self._col = 1 self._line += 1 - elif self._char.isdigit(): + elif self._char.isdigit(): # type:ignore self._scan_number() elif identifier_end: self._scan_identifier(identifier_end) @@ -727,38 +760,51 @@ class Tokenizer(metaclass=_Tokenizer): self._scan_keywords() return self.tokens - def _chars(self, size): + def _chars(self, size: int) -> str: if size == 1: - return self._char + return self._char # type: ignore start = self._current - 1 end = start + size if end <= self.size: return self.sql[start:end] return "" - def _advance(self, i=1): + def _advance(self, i: int = 1) -> None: self._col += i self._current += i - self._end = self._current >= self.size - self._char = self.sql[self._current - 1] - self._peek = self.sql[self._current] if self._current < self.size else "" + self._end = self._current >= self.size # type: ignore + self._char = self.sql[self._current - 1] # type: ignore + self._peek = self.sql[self._current] if self._current < self.size else "" # type: ignore @property - def _text(self): + def _text(self) -> str: return self.sql[self._start : self._current] - def _add(self, token_type, text=None): - self._prev_token_type = token_type - self.tokens.append(Token(token_type, self._text if text is None else text, self._line, self._col)) + def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None: + self._prev_token_line = self._line + self._prev_token_comment = self._comment + self._prev_token_type = token_type # type: ignore + self.tokens.append( + Token( + token_type, + self._text if text is None else text, + self._line, + self._col, + self._comment, + ) + ) + self._comment = None - if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON): + if token_type in self.COMMANDS and ( + len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON + ): self._start = self._current while not self._end and self._peek != ";": self._advance() if self._start < self._current: self._add(TokenType.STRING) - def _scan_keywords(self): + def _scan_keywords(self) -> None: size = 0 word = None chars = self._text @@ -771,7 +817,7 @@ class Tokenizer(metaclass=_Tokenizer): if skip: result = 1 else: - result, trie = in_trie(trie, char.upper()) + result, trie = in_trie(trie, char.upper()) # type: ignore if result == 0: break @@ -793,15 +839,11 @@ class Tokenizer(metaclass=_Tokenizer): else: skip = True else: - chars = None + chars = None # type: ignore if not word: if self._char in self.SINGLE_TOKENS: - token = self.SINGLE_TOKENS[self._char] - if token == TokenType.ANNOTATION: - 
self._scan_annotation() - return - self._add(token) + self._add(self.SINGLE_TOKENS[self._char]) # type: ignore return self._scan_var() return @@ -816,31 +858,41 @@ class Tokenizer(metaclass=_Tokenizer): self._advance(size - 1) self._add(self.KEYWORDS[word.upper()]) - def _scan_comment(self, comment_start): - if comment_start not in self._COMMENTS: + def _scan_comment(self, comment_start: str) -> bool: + if comment_start not in self._COMMENTS: # type: ignore return False - comment_end = self._COMMENTS[comment_start] + comment_start_line = self._line + comment_start_size = len(comment_start) + comment_end = self._COMMENTS[comment_start] # type: ignore if comment_end: comment_end_size = len(comment_end) while not self._end and self._chars(comment_end_size) != comment_end: self._advance() + + self._comment = self._text[comment_start_size : -comment_end_size + 1] # type: ignore self._advance(comment_end_size - 1) else: - while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: + while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore self._advance() - return True + self._comment = self._text[comment_start_size:] # type: ignore - def _scan_annotation(self): - while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK and self._peek != ",": - self._advance() - self._add(TokenType.ANNOTATION, self._text[1:]) + # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both + # types of comment can be attached to a token, the trailing one is discarded in favour of the leading one. - def _scan_number(self): + if comment_start_line == self._prev_token_line: + if self._prev_token_comment is None: + self.tokens[-1].comment = self._comment + + self._comment = None + + return True + + def _scan_number(self) -> None: if self._char == "0": - peek = self._peek.upper() + peek = self._peek.upper() # type: ignore if peek == "B": return self._scan_bits() elif peek == "X": @@ -850,7 +902,7 @@ class Tokenizer(metaclass=_Tokenizer): scientific = 0 while True: - if self._peek.isdigit(): + if self._peek.isdigit(): # type: ignore self._advance() elif self._peek == "." 
and not decimal: decimal = True @@ -858,25 +910,25 @@ class Tokenizer(metaclass=_Tokenizer): elif self._peek in ("-", "+") and scientific == 1: scientific += 1 self._advance() - elif self._peek.upper() == "E" and not scientific: + elif self._peek.upper() == "E" and not scientific: # type: ignore scientific += 1 self._advance() - elif self._peek.isalpha(): + elif self._peek.isalpha(): # type: ignore self._add(TokenType.NUMBER) literal = [] - while self._peek.isalpha(): - literal.append(self._peek.upper()) + while self._peek.isalpha(): # type: ignore + literal.append(self._peek.upper()) # type: ignore self._advance() - literal = "".join(literal) - token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) + literal = "".join(literal) # type: ignore + token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) # type: ignore if token_type: self._add(TokenType.DCOLON, "::") - return self._add(token_type, literal) + return self._add(token_type, literal) # type: ignore return self._advance(-len(literal)) else: return self._add(TokenType.NUMBER) - def _scan_bits(self): + def _scan_bits(self) -> None: self._advance() value = self._extract_value() try: @@ -884,7 +936,7 @@ class Tokenizer(metaclass=_Tokenizer): except ValueError: self._add(TokenType.IDENTIFIER) - def _scan_hex(self): + def _scan_hex(self) -> None: self._advance() value = self._extract_value() try: @@ -892,9 +944,9 @@ class Tokenizer(metaclass=_Tokenizer): except ValueError: self._add(TokenType.IDENTIFIER) - def _extract_value(self): + def _extract_value(self) -> str: while True: - char = self._peek.strip() + char = self._peek.strip() # type: ignore if char and char not in self.SINGLE_TOKENS: self._advance() else: @@ -902,31 +954,30 @@ class Tokenizer(metaclass=_Tokenizer): return self._text - def _scan_string(self, quote): - quote_end = self._QUOTES.get(quote) + def _scan_string(self, quote: str) -> bool: + quote_end = self._QUOTES.get(quote) # type: ignore if quote_end is None: return False self._advance(len(quote)) text = self._extract_string(quote_end) - - text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text - text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text + text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore + text = text.replace("\\\\", "\\") if self._replace_backslash else text self._add(TokenType.STRING, text) return True # X'1234, b'0110', E'\\\\\' etc. 
- def _scan_formatted_string(self, string_start): - if string_start in self._HEX_STRINGS: - delimiters = self._HEX_STRINGS + def _scan_formatted_string(self, string_start: str) -> bool: + if string_start in self._HEX_STRINGS: # type: ignore + delimiters = self._HEX_STRINGS # type: ignore token_type = TokenType.HEX_STRING base = 16 - elif string_start in self._BIT_STRINGS: - delimiters = self._BIT_STRINGS + elif string_start in self._BIT_STRINGS: # type: ignore + delimiters = self._BIT_STRINGS # type: ignore token_type = TokenType.BIT_STRING base = 2 - elif string_start in self._BYTE_STRINGS: - delimiters = self._BYTE_STRINGS + elif string_start in self._BYTE_STRINGS: # type: ignore + delimiters = self._BYTE_STRINGS # type: ignore token_type = TokenType.BYTE_STRING base = None else: @@ -942,11 +993,13 @@ class Tokenizer(metaclass=_Tokenizer): try: self._add(token_type, f"{int(text, base)}") except: - raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}") + raise RuntimeError( + f"Numeric string contains invalid characters from {self._line}:{self._start}" + ) return True - def _scan_identifier(self, identifier_end): + def _scan_identifier(self, identifier_end: str) -> None: while self._peek != identifier_end: if self._end: raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}") @@ -954,9 +1007,9 @@ class Tokenizer(metaclass=_Tokenizer): self._advance() self._add(TokenType.IDENTIFIER, self._text[1:-1]) - def _scan_var(self): + def _scan_var(self) -> None: while True: - char = self._peek.strip() + char = self._peek.strip() # type: ignore if char and char not in self.SINGLE_TOKENS: self._advance() else: @@ -967,12 +1020,12 @@ class Tokenizer(metaclass=_Tokenizer): else self.KEYWORDS.get(self._text.upper(), TokenType.VAR) ) - def _extract_string(self, delimiter): + def _extract_string(self, delimiter: str) -> str: text = "" delim_size = len(delimiter) while True: - if self._char == self.ESCAPE and self._peek == delimiter: + if self._char in self._ESCAPES and self._peek == delimiter: # type: ignore text += delimiter self._advance(2) else: @@ -983,7 +1036,7 @@ class Tokenizer(metaclass=_Tokenizer): if self._end: raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}") - text += self._char + text += self._char # type: ignore self._advance() return text diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 014ae00..412b881 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -1,7 +1,14 @@ +from __future__ import annotations + +import typing as t + +if t.TYPE_CHECKING: + from sqlglot.generator import Generator + from sqlglot import expressions as exp -def unalias_group(expression): +def unalias_group(expression: exp.Expression) -> exp.Expression: """ Replace references to select aliases in GROUP BY clauses. @@ -9,6 +16,12 @@ def unalias_group(expression): >>> import sqlglot >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 'SELECT a AS b FROM x GROUP BY 1' + + Args: + expression: the expression that will be transformed. + + Returns: + The transformed expression. 
""" if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): aliased_selects = { @@ -30,19 +43,20 @@ def unalias_group(expression): return expression -def preprocess(transforms, to_sql): +def preprocess( + transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], + to_sql: t.Callable[[Generator, exp.Expression], str], +) -> t.Callable[[Generator, exp.Expression], str]: """ - Create a new transform function that can be used a value in `Generator.TRANSFORMS` - to convert expressions to SQL. + Creates a new transform by chaining a sequence of transformations and converts the resulting + expression to SQL, using an appropriate `Generator.TRANSFORMS` function. Args: - transforms (list[(exp.Expression) -> exp.Expression]): - Sequence of transform functions. These will be called in order. - to_sql ((sqlglot.generator.Generator, exp.Expression) -> str): - Final transform that converts the resulting expression to a SQL string. + transforms: sequence of transform functions. These will be called in order. + to_sql: final transform that converts the resulting expression to a SQL string. + Returns: - (sqlglot.generator.Generator, exp.Expression) -> str: - Function that can be used as a generator transform. + Function that can be used as a generator transform. """ def _to_sql(self, expression): @@ -54,12 +68,10 @@ def preprocess(transforms, to_sql): return _to_sql -def delegate(attr): +def delegate(attr: str) -> t.Callable: """ - Create a new method that delegates to `attr`. - - This is useful for creating `Generator.TRANSFORMS` functions that delegate - to existing generator methods. + Create a new method that delegates to `attr`. This is useful for creating `Generator.TRANSFORMS` + functions that delegate to existing generator methods. """ def _transform(self, *args, **kwargs): diff --git a/sqlglot/trie.py b/sqlglot/trie.py index a234107..fa2aaf1 100644 --- a/sqlglot/trie.py +++ b/sqlglot/trie.py @@ -1,5 +1,26 @@ -def new_trie(keywords): - trie = {} +import typing as t + +key = t.Sequence[t.Hashable] + + +def new_trie(keywords: t.Iterable[key]) -> t.Dict: + """ + Creates a new trie out of a collection of keywords. + + The trie is represented as a sequence of nested dictionaries keyed by either single character + strings, or by 0, which is used to designate that a keyword is in the trie. + + Example: + >>> new_trie(["bla", "foo", "blab"]) + {'b': {'l': {'a': {0: True, 'b': {0: True}}}}, 'f': {'o': {'o': {0: True}}}} + + Args: + keywords: the keywords to create the trie from. + + Returns: + The trie corresponding to `keywords`. + """ + trie: t.Dict = {} for key in keywords: current = trie @@ -11,7 +32,28 @@ def new_trie(keywords): return trie -def in_trie(trie, key): +def in_trie(trie: t.Dict, key: key) -> t.Tuple[int, t.Dict]: + """ + Checks whether a key is in a trie. + + Examples: + >>> in_trie(new_trie(["cat"]), "bob") + (0, {'c': {'a': {'t': {0: True}}}}) + + >>> in_trie(new_trie(["cat"]), "ca") + (1, {'t': {0: True}}) + + >>> in_trie(new_trie(["cat"]), "cat") + (2, {0: True}) + + Args: + trie: the trie to be searched. + key: the target key. + + Returns: + A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the search stops, and `value` + is either 0 (search was unsuccessfull), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key is in `trie`). + """ if not key: return (0, trie) -- cgit v1.2.3