diff options
Diffstat (limited to '')
98 files changed, 4070 insertions, 1656 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 4fc508f..87dd21d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,29 @@ Changelog ========= +v10.0.0 +------ + +Changes: + +- Breaking: replaced SQLGlot annotations with comments. Now comments can be preserved after transpilation, and they can appear in other places besides SELECT's expressions. +- Breaking: renamed list_get to seq_get. +- Breaking: activated mypy type checking for SQLGlot. +- New: Azure Databricks support. +- New: placeholders can now be replaced in an expression. +- New: null safe equal operator (<=>). +- New: [SET statements](https://github.com/tobymao/sqlglot/pull/673) for MySQL. +- New: [SHOW commands](https://dev.mysql.com/doc/refman/8.0/en/show.html) for MySQL. +- New: [FORMAT function](https://www.w3schools.com/sql/func_sqlserver_format.asp) for TSQL. +- New: CROSS APPLY / OUTER APPLY [support](https://github.com/tobymao/sqlglot/pull/641) for TSQL. +- New: added formats for TSQL's [DATENAME/DATEPART functions](https://learn.microsoft.com/en-us/sql/t-sql/functions/datename-transact-sql?view=sql-server-ver16) +- New: added styles for TSQL's [CONVERT function](https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16). +- Improvement: [refactored the schema](https://github.com/tobymao/sqlglot/pull/668) to be more lenient; before it needed to do an exact match of db.table, now it finds table if there are no ambiguities. +- Improvement: allow functions to [inherit](https://github.com/tobymao/sqlglot/pull/674) their arguments' types, so that annotating CASE, IF etc. is possible. +- Improvement: allow [joining with same names](https://github.com/tobymao/sqlglot/pull/660) in the python executor. +- Improvement: the "using" field can now be set for the [join expression builders](https://github.com/tobymao/sqlglot/pull/636). +- Improvement: qualify_columns [now qualifies](https://github.com/tobymao/sqlglot/pull/635) only non-alias columns in the having clause. + v9.0.0 ------ @@ -14,7 +14,7 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/ * [Install](#install) * [Documentation](#documentation) -* [Run Tests & Lint](#run-tests-and-lint) +* [Run Tests and Lint](#run-tests-and-lint) * [Examples](#examples) * [Formatting and Transpiling](#formatting-and-transpiling) * [Metadata](#metadata) @@ -22,7 +22,6 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/ * [Unsupported Errors](#unsupported-errors) * [Build and Modify SQL](#build-and-modify-sql) * [SQL Optimizer](#sql-optimizer) - * [SQL Annotations](#sql-annotations) * [AST Introspection](#ast-introspection) * [AST Diff](#ast-diff) * [Custom Dialects](#custom-dialects) @@ -51,7 +50,7 @@ pip3 install -r dev-requirements.txt ## Documentation -SQLGlot's uses [pdocs](https://pdoc.dev/) to serve its API documentation: +SQLGlot uses [pdocs](https://pdoc.dev/) to serve its API documentation: ``` pdoc sqlglot --docformat google @@ -121,6 +120,39 @@ LEFT JOIN `baz` ON `f`.`a` = `baz`.`a` ``` +Comments are also preserved in a best-effort basis when transpiling SQL code: + +```python +sql = """ +/* multi + line + comment +*/ +SELECT + tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */, + CAST(x AS INT), # comment 3 + y -- comment 4 +FROM + bar /* comment 5 */, + tbl # comment 6 +""" + +print(sqlglot.transpile(sql, read='mysql', pretty=True)[0]) +``` + +```sql +/* multi + line + comment +*/ +SELECT + tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */, + CAST(x AS INT), -- comment 3 + y -- comment 4 +FROM bar /* comment 5 */, tbl /* comment 6*/ +``` + + ### Metadata You can explore SQL with expression helpers to do things like find columns and tables: @@ -249,17 +281,6 @@ WHERE "x"."Z" = CAST('2021-02-01' AS DATE) ``` -### SQL Annotations - -SQLGlot supports annotations in the sql expression. This is an experimental feature that is not part of any of the SQL standards but it can be useful when needing to annotate what a selected field is supposed to be. Below is an example: - -```sql -SELECT - user # primary_key, - country -FROM users -``` - ### AST Introspection You can see the AST version of the sql by calling `repr`: diff --git a/run_checks.sh b/run_checks.sh index b13a61c..187e6b9 100755 --- a/run_checks.sh +++ b/run_checks.sh @@ -1,15 +1,8 @@ #!/bin/bash -e - [[ -z "${GITHUB_ACTIONS}" ]] && RETURN_ERROR_CODE='' || RETURN_ERROR_CODE='--check' - -python -m autoflake -i -r ${RETURN_ERROR_CODE} \ - --expand-star-imports \ - --remove-all-unused-imports \ - --ignore-init-module-imports \ - --remove-duplicate-keys \ - --remove-unused-variables \ - sqlglot/ tests/ -python -m isort --profile black sqlglot/ tests/ -python -m black ${RETURN_ERROR_CODE} --line-length 120 sqlglot/ tests/ -python -m mypy sqlglot tests +TARGETS="sqlglot/ tests/" +python -m mypy $TARGETS +python -m autoflake -i -r ${RETURN_ERROR_CODE} $TARGETS +python -m isort $TARGETS +python -m black --line-length 100 ${RETURN_ERROR_CODE} $TARGETS python -m unittest @@ -3,7 +3,7 @@ disallow_untyped_calls = False no_implicit_optional = True [mypy-sqlglot.*] -ignore_errors = True +ignore_errors = False [mypy-sqlglot.dataframe.*] ignore_errors = False @@ -13,3 +13,16 @@ ignore_errors = True [mypy-tests.dataframe.*] ignore_errors = False + +[autoflake] +in-place = True +expand-star-imports = True +remove-all-unused-imports = True +ignore-init-module-imports = True +remove-duplicate-keys = True +remove-unused-variables = True +quiet = True + +[isort] +profile=black +known_first_party=sqlglot @@ -21,6 +21,7 @@ setup( author_email="toby.mao@gmail.com", license="MIT", packages=find_packages(include=["sqlglot", "sqlglot.*"]), + package_data={"sqlglot": ["py.typed"]}, classifiers=[ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index d6e18fd..6e67b19 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -1,5 +1,9 @@ """## Python SQL parser, transpiler and optimizer.""" +from __future__ import annotations + +import typing as t + from sqlglot import expressions as exp from sqlglot.dialects import Dialect, Dialects from sqlglot.diff import diff @@ -20,51 +24,54 @@ from sqlglot.expressions import ( subquery, ) from sqlglot.expressions import table_ as table -from sqlglot.expressions import union +from sqlglot.expressions import to_column, to_table, union from sqlglot.generator import Generator from sqlglot.parser import Parser from sqlglot.schema import MappingSchema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "9.0.6" +__version__ = "10.0.1" pretty = False schema = MappingSchema() -def parse(sql, read=None, **opts): +def parse( + sql: str, read: t.Optional[str | Dialect] = None, **opts +) -> t.List[t.Optional[Expression]]: """ - Parses the given SQL string into a collection of syntax trees, one per - parsed SQL statement. + Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement. Args: - sql (str): the SQL code string to parse. - read (str): the SQL dialect to apply during parsing - (eg. "spark", "hive", "presto", "mysql"). + sql: the SQL code string to parse. + read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). **opts: other options. Returns: - typing.List[Expression]: the list of parsed syntax trees. + The resulting syntax tree collection. """ dialect = Dialect.get_or_raise(read)() return dialect.parse(sql, **opts) -def parse_one(sql, read=None, into=None, **opts): +def parse_one( + sql: str, + read: t.Optional[str | Dialect] = None, + into: t.Optional[Expression | str] = None, + **opts, +) -> t.Optional[Expression]: """ - Parses the given SQL string and returns a syntax tree for the first - parsed SQL statement. + Parses the given SQL string and returns a syntax tree for the first parsed SQL statement. Args: - sql (str): the SQL code string to parse. - read (str): the SQL dialect to apply during parsing - (eg. "spark", "hive", "presto", "mysql"). - into (Expression): the SQLGlot Expression to parse into + sql: the SQL code string to parse. + read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). + into: the SQLGlot Expression to parse into. **opts: other options. Returns: - Expression: the syntax tree for the first parsed statement. + The syntax tree for the first parsed statement. """ dialect = Dialect.get_or_raise(read)() @@ -77,25 +84,29 @@ def parse_one(sql, read=None, into=None, **opts): return result[0] if result else None -def transpile(sql, read=None, write=None, identity=True, error_level=None, **opts): +def transpile( + sql: str, + read: t.Optional[str | Dialect] = None, + write: t.Optional[str | Dialect] = None, + identity: bool = True, + error_level: t.Optional[ErrorLevel] = None, + **opts, +) -> t.List[str]: """ - Parses the given SQL string using the source dialect and returns a list of SQL strings - transformed to conform to the target dialect. Each string in the returned list represents - a single transformed SQL statement. + Parses the given SQL string in accordance with the source dialect and returns a list of SQL strings transformed + to conform to the target dialect. Each string in the returned list represents a single transformed SQL statement. Args: - sql (str): the SQL code string to transpile. - read (str): the source dialect used to parse the input string - (eg. "spark", "hive", "presto", "mysql"). - write (str): the target dialect into which the input should be transformed - (eg. "spark", "hive", "presto", "mysql"). - identity (bool): if set to True and if the target dialect is not specified - the source dialect will be used as both: the source and the target dialect. - error_level (ErrorLevel): the desired error level of the parser. + sql: the SQL code string to transpile. + read: the source dialect used to parse the input string (eg. "spark", "hive", "presto", "mysql"). + write: the target dialect into which the input should be transformed (eg. "spark", "hive", "presto", "mysql"). + identity: if set to `True` and if the target dialect is not specified the source dialect will be used as both: + the source and the target dialect. + error_level: the desired error level of the parser. **opts: other options. Returns: - typing.List[str]: the list of transpiled SQL statements / expressions. + The list of transpiled SQL statements. """ write = write or read if identity else write return [ diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py index c0fa380..42a54bc 100644 --- a/sqlglot/__main__.py +++ b/sqlglot/__main__.py @@ -49,7 +49,10 @@ args = parser.parse_args() error_level = sqlglot.ErrorLevel[args.error_level.upper()] if args.parse: - sqls = [repr(expression) for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)] + sqls = [ + repr(expression) + for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level) + ] else: sqls = sqlglot.transpile( args.sql, diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.pyi index f1a03ea..67c8c09 100644 --- a/sqlglot/dataframe/sql/_typing.pyi +++ b/sqlglot/dataframe/sql/_typing.pyi @@ -10,11 +10,17 @@ if t.TYPE_CHECKING: from sqlglot.dataframe.sql.types import StructType ColumnLiterals = t.TypeVar( - "ColumnLiterals", bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime] + "ColumnLiterals", + bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime], ) ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str]) ColumnOrLiteral = t.TypeVar( - "ColumnOrLiteral", bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime] + "ColumnOrLiteral", + bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime], +) +SchemaInput = t.TypeVar( + "SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]] +) +OutputExpressionContainer = t.TypeVar( + "OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert] ) -SchemaInput = t.TypeVar("SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]]) -OutputExpressionContainer = t.TypeVar("OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert]) diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index e66aaa8..f9e1c5b 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -18,7 +18,11 @@ class Column: expression = expression.expression # type: ignore elif expression is None or not isinstance(expression, (str, exp.Expression)): expression = self._lit(expression).expression # type: ignore - self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark") + + expression = sqlglot.maybe_parse(expression, dialect="spark") + if expression is None: + raise ValueError(f"Could not parse {expression}") + self.expression: exp.Expression = expression def __repr__(self): return repr(self.expression) @@ -135,21 +139,29 @@ class Column: ) -> Column: ensured_column = None if column is None else cls.ensure_col(column) ensure_expression_values = { - k: [Column.ensure_col(x).expression for x in v] if is_iterable(v) else Column.ensure_col(v).expression + k: [Column.ensure_col(x).expression for x in v] + if is_iterable(v) + else Column.ensure_col(v).expression for k, v in kwargs.items() } new_expression = ( callable_expression(**ensure_expression_values) if ensured_column is None - else callable_expression(this=ensured_column.column_expression, **ensure_expression_values) + else callable_expression( + this=ensured_column.column_expression, **ensure_expression_values + ) ) return Column(new_expression) def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: - return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)) + return Column( + klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs) + ) def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: - return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)) + return Column( + klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs) + ) def unary_op(self, klass: t.Callable, **kwargs) -> Column: return Column(klass(this=self.column_expression, **kwargs)) @@ -188,7 +200,7 @@ class Column: expression.set("table", exp.to_identifier(table_name)) return Column(expression) - def sql(self, **kwargs) -> Column: + def sql(self, **kwargs) -> str: return self.expression.sql(**{"dialect": "spark", **kwargs}) def alias(self, name: str) -> Column: @@ -265,10 +277,14 @@ class Column: ) def like(self, other: str): - return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression) + return self.invoke_expression_over_column( + self, exp.Like, expression=self._lit(other).expression + ) def ilike(self, other: str): - return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression) + return self.invoke_expression_over_column( + self, exp.ILike, expression=self._lit(other).expression + ) def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column: startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos @@ -287,10 +303,18 @@ class Column: lowerBound: t.Union[ColumnOrLiteral], upperBound: t.Union[ColumnOrLiteral], ) -> Column: - lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound - upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound + lower_bound_exp = ( + self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound + ) + upper_bound_exp = ( + self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound + ) return Column( - exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression) + exp.Between( + this=self.column_expression, + low=lower_bound_exp.expression, + high=upper_bound_exp.expression, + ) ) def over(self, window: WindowSpec) -> Column: diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 322dcf2..40cd6c9 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -21,7 +21,12 @@ from sqlglot.optimizer import optimize as optimize_func from sqlglot.optimizer.qualify_columns import qualify_columns if t.TYPE_CHECKING: - from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer + from sqlglot.dataframe.sql._typing import ( + ColumnLiterals, + ColumnOrLiteral, + ColumnOrName, + OutputExpressionContainer, + ) from sqlglot.dataframe.sql.session import SparkSession @@ -83,7 +88,9 @@ class DataFrame: return from_exp.alias_or_name table_alias = from_exp.find(exp.TableAlias) if not table_alias: - raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}") + raise RuntimeError( + f"Could not find an alias name for this expression: {self.expression}" + ) return table_alias.alias_or_name return self.expression.ctes[-1].alias @@ -132,12 +139,16 @@ class DataFrame: cte.set("sequence_id", sequence_id or self.sequence_id) return cte, name - def _ensure_list_of_columns( - self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]] - ) -> t.List[Column]: - columns = ensure_list(cols) - columns = Column.ensure_cols(columns) - return columns + @t.overload + def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: + ... + + @t.overload + def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: + ... + + def _ensure_list_of_columns(self, cols): + return Column.ensure_cols(ensure_list(cols)) def _ensure_and_normalize_cols(self, cols): cols = self._ensure_list_of_columns(cols) @@ -153,10 +164,16 @@ class DataFrame: df = self._resolve_pending_hints() sequence_id = sequence_id or df.sequence_id expression = df.expression.copy() - cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id) - new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression]) + cte_expression, cte_name = df._create_cte_from_expression( + expression=expression, sequence_id=sequence_id + ) + new_expression = df._add_ctes_to_expression( + exp.Select(), expression.ctes + [cte_expression] + ) sel_columns = df._get_outer_select_columns(cte_expression) - new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns]) + new_expression = new_expression.from_(cte_name).select( + *[x.alias_or_name for x in sel_columns] + ) return df.copy(expression=new_expression, sequence_id=sequence_id) def _resolve_pending_hints(self) -> DataFrame: @@ -169,16 +186,23 @@ class DataFrame: hint_expression.args.get("expressions").append(hint) df.pending_hints.remove(hint) - join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)} + join_aliases = { + join_table.alias_or_name + for join_table in get_tables_from_expression_with_join(expression) + } if join_aliases: for hint in df.pending_join_hints: for sequence_id_expression in hint.expressions: sequence_id_or_name = sequence_id_expression.alias_or_name sequence_ids_to_match = [sequence_id_or_name] if sequence_id_or_name in df.spark.name_to_sequence_id_mapping: - sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name] + sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[ + sequence_id_or_name + ] matching_ctes = [ - cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match + cte + for cte in reversed(expression.ctes) + if cte.args["sequence_id"] in sequence_ids_to_match ] for matching_cte in matching_ctes: if matching_cte.alias_or_name in join_aliases: @@ -193,9 +217,14 @@ class DataFrame: def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame: hint_name = hint_name.upper() hint_expression = ( - exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args]) + exp.JoinHint( + this=hint_name, + expressions=[exp.to_table(parameter.alias_or_name) for parameter in args], + ) if hint_name in JOIN_HINTS - else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args]) + else exp.Anonymous( + this=hint_name, expressions=[parameter.expression for parameter in args] + ) ) new_df = self.copy() new_df.pending_hints.append(hint_expression) @@ -245,7 +274,9 @@ class DataFrame: def _get_select_expressions( self, ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]: - select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = [] + select_expressions: t.List[ + t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select] + ] = [] main_select_ctes: t.List[exp.CTE] = [] for cte in self.expression.ctes: cache_storage_level = cte.args.get("cache_storage_level") @@ -279,14 +310,19 @@ class DataFrame: cache_table_name = df._create_hash_from_expression(select_expression) cache_table = exp.to_table(cache_table_name) original_alias_name = select_expression.args["cte_alias_name"] - replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name) + + replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore + cache_table_name + ) sqlglot.schema.add_table(cache_table_name, select_expression.named_selects) cache_storage_level = select_expression.args["cache_storage_level"] options = [ exp.Literal.string("storageLevel"), exp.Literal.string(cache_storage_level), ] - expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options) + expression = exp.Cache( + this=cache_table, expression=select_expression, lazy=True, options=options + ) # We will drop the "view" if it exists before running the cache table output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW")) elif expression_type == exp.Create: @@ -305,7 +341,9 @@ class DataFrame: raise ValueError(f"Invalid expression type: {expression_type}") output_expressions.append(expression) - return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions] + return [ + expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions + ] def copy(self, **kwargs) -> DataFrame: return DataFrame(**object_to_dict(self, **kwargs)) @@ -317,7 +355,9 @@ class DataFrame: if self.expression.args.get("joins"): ambiguous_cols = [col for col in cols if not col.column_expression.table] if ambiguous_cols: - join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)] + join_table_identifiers = [ + x.this for x in get_tables_from_expression_with_join(self.expression) + ] cte_names_in_join = [x.this for x in join_table_identifiers] for ambiguous_col in ambiguous_cols: ctes_with_column = [ @@ -367,14 +407,20 @@ class DataFrame: @operation(Operation.FROM) def join( - self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs + self, + other_df: DataFrame, + on: t.Union[str, t.List[str], Column, t.List[Column]], + how: str = "inner", + **kwargs, ) -> DataFrame: other_df = other_df._convert_leaf_to_cte() pre_join_self_latest_cte_name = self.latest_cte_name columns = self._ensure_and_normalize_cols(on) join_type = how.replace("_", " ") if isinstance(columns[0].expression, exp.Column): - join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns] + join_columns = [ + Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns + ] join_clause = functools.reduce( lambda x, y: x & y, [ @@ -402,7 +448,9 @@ class DataFrame: for column in self._get_outer_select_columns(other_df) ] column_value_mapping = { - column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column + column.alias_or_name + if not isinstance(column.expression.this, exp.Star) + else column.sql(): column for column in other_columns + self_columns + join_columns } all_columns = [ @@ -410,16 +458,22 @@ class DataFrame: for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns} ] new_df = self.copy( - expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type) + expression=self.expression.join( + other_df.latest_cte_name, on=join_clause.expression, join_type=join_type + ) + ) + new_df.expression = new_df._add_ctes_to_expression( + new_df.expression, other_df.expression.ctes ) - new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes) new_df.pending_hints.extend(other_df.pending_hints) new_df = new_df.select.__wrapped__(new_df, *all_columns) return new_df @operation(Operation.ORDER_BY) def orderBy( - self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None + self, + *cols: t.Union[str, Column], + ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None, ) -> DataFrame: """ This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark @@ -429,7 +483,10 @@ class DataFrame: columns = self._ensure_and_normalize_cols(cols) pre_ordered_col_indexes = [ x - for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)] + for x in [ + i if isinstance(col.expression, exp.Ordered) else None + for i, col in enumerate(columns) + ] if x is not None ] if ascending is None: @@ -478,7 +535,9 @@ class DataFrame: for r_column in r_columns_unused: l_expressions.append(exp.alias_(exp.Null(), r_column)) r_expressions.append(r_column) - r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions)) + r_df = ( + other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions)) + ) l_df = self.copy() if allowMissingColumns: l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions)) @@ -536,7 +595,9 @@ class DataFrame: f"The minimum num nulls for dropna must be less than or equal to the number of columns. " f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}" ) - if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns] + if_null_checks = [ + F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns + ] nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks) num_nulls = nulls_added_together.alias("num_nulls") new_df = new_df.select(num_nulls, append=True) @@ -576,11 +637,15 @@ class DataFrame: value_columns = [lit(value) for value in values] null_replacement_mapping = { - column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)) + column.alias_or_name: ( + F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name) + ) for column, value in zip(columns, value_columns) } null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping} - null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns] + null_replacement_columns = [ + null_replacement_mapping[column.alias_or_name] for column in all_columns + ] new_df = new_df.select(*null_replacement_columns) return new_df @@ -589,12 +654,11 @@ class DataFrame: self, to_replace: t.Union[bool, int, float, str, t.List, t.Dict], value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, - subset: t.Optional[t.Union[str, t.List[str]]] = None, + subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None, ) -> DataFrame: from sqlglot.dataframe.sql.functions import lit old_values = None - subset = ensure_list(subset) new_df = self.copy() all_columns = self._get_outer_select_columns(new_df.expression) all_column_mapping = {column.alias_or_name: column for column in all_columns} @@ -605,7 +669,9 @@ class DataFrame: new_values = list(to_replace.values()) elif not old_values and isinstance(to_replace, list): assert isinstance(value, list), "value must be a list since the replacements are a list" - assert len(to_replace) == len(value), "the replacements and values must be the same length" + assert len(to_replace) == len( + value + ), "the replacements and values must be the same length" old_values = to_replace new_values = value else: @@ -635,7 +701,9 @@ class DataFrame: def withColumn(self, colName: str, col: Column) -> DataFrame: col = self._ensure_and_normalize_col(col) existing_col_names = self.expression.named_selects - existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None + existing_col_index = ( + existing_col_names.index(colName) if colName in existing_col_names else None + ) if existing_col_index: expression = self.expression.copy() expression.expressions[existing_col_index] = col.expression @@ -645,7 +713,11 @@ class DataFrame: @operation(Operation.SELECT) def withColumnRenamed(self, existing: str, new: str): expression = self.expression.copy() - existing_columns = [expression for expression in expression.expressions if expression.alias_or_name == existing] + existing_columns = [ + expression + for expression in expression.expressions + if expression.alias_or_name == existing + ] if not existing_columns: raise ValueError("Tried to rename a column that doesn't exist") for existing_column in existing_columns: @@ -674,15 +746,19 @@ class DataFrame: def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame: parameter_list = ensure_list(parameters) parameter_columns = ( - self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id]) + self._ensure_list_of_columns(parameter_list) + if parameters + else Column.ensure_cols([self.sequence_id]) ) return self._hint(name, parameter_columns) @operation(Operation.NO_OP) - def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame: - num_partitions = Column.ensure_cols(ensure_list(numPartitions)) + def repartition( + self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName + ) -> DataFrame: + num_partition_cols = self._ensure_list_of_columns(numPartitions) columns = self._ensure_and_normalize_cols(cols) - args = num_partitions + columns + args = num_partition_cols + columns return self._hint("repartition", args) @operation(Operation.NO_OP) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index bc002e5..dbfb06f 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -45,7 +45,11 @@ def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: def when(condition: Column, value: t.Any) -> Column: true_value = value if isinstance(value, Column) else lit(value) - return Column(glotexp.Case(ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)])) + return Column( + glotexp.Case( + ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)] + ) + ) def asc(col: ColumnOrName) -> Column: @@ -407,7 +411,9 @@ def percentile_approx( return Column.invoke_expression_over_column( col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy ) - return Column.invoke_expression_over_column(col, glotexp.ApproxQuantile, quantile=lit(percentage)) + return Column.invoke_expression_over_column( + col, glotexp.ApproxQuantile, quantile=lit(percentage) + ) def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column: @@ -471,7 +477,9 @@ def factorial(col: ColumnOrName) -> Column: return Column.invoke_anonymous_function(col, "FACTORIAL") -def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None) -> Column: +def lag( + col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None +) -> Column: if default is not None: return Column.invoke_anonymous_function(col, "LAG", offset, default) if offset != 1: @@ -479,7 +487,9 @@ def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[Colu return Column.invoke_anonymous_function(col, "LAG") -def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None) -> Column: +def lead( + col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None +) -> Column: if default is not None: return Column.invoke_anonymous_function(col, "LEAD", offset, default) if offset != 1: @@ -487,7 +497,9 @@ def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.A return Column.invoke_anonymous_function(col, "LEAD") -def nth_value(col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None) -> Column: +def nth_value( + col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None +) -> Column: if ignoreNulls is not None: raise NotImplementedError("There is currently not support for `ignoreNulls` parameter") if offset != 1: @@ -571,7 +583,9 @@ def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Colum return Column.invoke_anonymous_function(start, "ADD_MONTHS", months) -def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None) -> Column: +def months_between( + date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None +) -> Column: if roundOff is None: return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2) return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff) @@ -611,9 +625,13 @@ def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column: return Column.invoke_expression_over_column(col, glotexp.UnixToStr) -def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column: +def unix_timestamp( + timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None +) -> Column: if format is not None: - return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix, format=lit(format)) + return Column.invoke_expression_over_column( + timestamp, glotexp.StrToUnix, format=lit(format) + ) return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix) @@ -642,7 +660,9 @@ def window( timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime) ) if slideDuration is not None: - return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration)) + return Column.invoke_anonymous_function( + timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration) + ) if startTime is not None: return Column.invoke_anonymous_function( timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime) @@ -731,7 +751,9 @@ def trim(col: ColumnOrName) -> Column: def concat_ws(sep: str, *cols: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols)) + return Column.invoke_expression_over_column( + None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols) + ) def decode(col: ColumnOrName, charset: str) -> Column: @@ -768,7 +790,9 @@ def overlay( def sentences( - string: ColumnOrName, language: t.Optional[ColumnOrName] = None, country: t.Optional[ColumnOrName] = None + string: ColumnOrName, + language: t.Optional[ColumnOrName] = None, + country: t.Optional[ColumnOrName] = None, ) -> Column: if language is not None and country is not None: return Column.invoke_anonymous_function(string, "SENTENCES", language, country) @@ -794,7 +818,9 @@ def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column: def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column: substr_col = lit(substr) if pos is not None: - return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col, position=pos) + return Column.invoke_expression_over_column( + str, glotexp.StrPosition, substr=substr_col, position=pos + ) return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col) @@ -872,7 +898,10 @@ def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore return Column.invoke_expression_over_column( - None, glotexp.VarMap, keys=array(*cols[::2]).expression, values=array(*cols[1::2]).expression + None, + glotexp.VarMap, + keys=array(*cols[::2]).expression, + values=array(*cols[1::2]).expression, ) @@ -882,29 +911,39 @@ def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column: def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column: value_col = value if isinstance(value, Column) else lit(value) - return Column.invoke_expression_over_column(col, glotexp.ArrayContains, expression=value_col.expression) + return Column.invoke_expression_over_column( + col, glotexp.ArrayContains, expression=value_col.expression + ) def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column: return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2)) -def slice(x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]) -> Column: +def slice( + x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int] +) -> Column: start_col = start if isinstance(start, Column) else lit(start) length_col = length if isinstance(length, Column) else lit(length) return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col) -def array_join(col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None) -> Column: +def array_join( + col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None +) -> Column: if null_replacement is not None: - return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement)) + return Column.invoke_anonymous_function( + col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement) + ) return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter)) def concat(*cols: ColumnOrName) -> Column: if len(cols) == 1: return Column.invoke_anonymous_function(cols[0], "CONCAT") - return Column.invoke_anonymous_function(cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]) + return Column.invoke_anonymous_function( + cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]] + ) def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column: @@ -1076,7 +1115,9 @@ def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:]) -def sequence(start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None) -> Column: +def sequence( + start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None +) -> Column: if step is not None: return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step) return Column.invoke_anonymous_function(start, "SEQUENCE", stop) @@ -1103,12 +1144,15 @@ def aggregate( merge_exp = _get_lambda_from_func(merge) if finish is not None: finish_exp = _get_lambda_from_func(finish) - return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)) + return Column.invoke_anonymous_function( + col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp) + ) return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp)) def transform( - col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]] + col: ColumnOrName, + f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], ) -> Column: f_expression = _get_lambda_from_func(f) return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression)) @@ -1124,12 +1168,17 @@ def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column: return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression)) -def filter(col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]) -> Column: +def filter( + col: ColumnOrName, + f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], +) -> Column: f_expression = _get_lambda_from_func(f) return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression) -def zip_with(left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column: +def zip_with( + left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column] +) -> Column: f_expression = _get_lambda_from_func(f) return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression)) @@ -1163,7 +1212,10 @@ def _lambda_quoted(value: str) -> t.Optional[bool]: def _get_lambda_from_func(lambda_expression: t.Callable): - variables = [glotexp.to_identifier(x, quoted=_lambda_quoted(x)) for x in lambda_expression.__code__.co_varnames] + variables = [ + glotexp.to_identifier(x, quoted=_lambda_quoted(x)) + for x in lambda_expression.__code__.co_varnames + ] return glotexp.Lambda( this=lambda_expression(*[Column(x) for x in variables]).expression, expressions=variables, diff --git a/sqlglot/dataframe/sql/group.py b/sqlglot/dataframe/sql/group.py index 947aace..ba27c17 100644 --- a/sqlglot/dataframe/sql/group.py +++ b/sqlglot/dataframe/sql/group.py @@ -17,7 +17,9 @@ class GroupedData: self.last_op = last_op self.group_by_cols = group_by_cols - def _get_function_applied_columns(self, func_name: str, cols: t.Tuple[str, ...]) -> t.List[Column]: + def _get_function_applied_columns( + self, func_name: str, cols: t.Tuple[str, ...] + ) -> t.List[Column]: func_name = func_name.lower() return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols] @@ -30,9 +32,9 @@ class GroupedData: ) cols = self._df._ensure_and_normalize_cols(columns) - expression = self._df.expression.group_by(*[x.expression for x in self.group_by_cols]).select( - *[x.expression for x in self.group_by_cols + cols], append=False - ) + expression = self._df.expression.group_by( + *[x.expression for x in self.group_by_cols] + ).select(*[x.expression for x in self.group_by_cols + cols], append=False) return self._df.copy(expression=expression) def count(self) -> DataFrame: diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py index 1513946..75feba7 100644 --- a/sqlglot/dataframe/sql/normalize.py +++ b/sqlglot/dataframe/sql/normalize.py @@ -23,7 +23,9 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[ replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier) -def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier): +def replace_alias_name_with_cte_name( + spark: SparkSession, expression_context: exp.Select, id: exp.Identifier +): if id.alias_or_name in spark.name_to_sequence_id_mapping: for cte in reversed(expression_context.ctes): if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]: @@ -40,8 +42,12 @@ def replace_branch_and_sequence_ids_with_cte_name( # id then it keeps that reference. This handles the weird edge case in spark that shouldn't # be common in practice if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids: - join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)] - ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases] + join_table_aliases = [ + x.alias_or_name for x in get_tables_from_expression_with_join(expression_context) + ] + ctes_in_join = [ + cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases + ] if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]: assert len(ctes_in_join) == 2 _set_alias_name(id, ctes_in_join[0].alias_or_name) @@ -58,7 +64,6 @@ def _set_alias_name(id: exp.Identifier, name: str): def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]: - values = ensure_list(values) results = [] for value in values: if isinstance(value, str): diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py index 4830035..febc664 100644 --- a/sqlglot/dataframe/sql/readwriter.py +++ b/sqlglot/dataframe/sql/readwriter.py @@ -19,12 +19,19 @@ class DataFrameReader: from sqlglot.dataframe.sql.dataframe import DataFrame sqlglot.schema.add_table(tableName) - return DataFrame(self.spark, exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName))) + return DataFrame( + self.spark, + exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)), + ) class DataFrameWriter: def __init__( - self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False + self, + df: DataFrame, + spark: t.Optional[SparkSession] = None, + mode: t.Optional[str] = None, + by_name: bool = False, ): self._df = df self._spark = spark or df.spark @@ -33,7 +40,10 @@ class DataFrameWriter: def copy(self, **kwargs) -> DataFrameWriter: return DataFrameWriter( - **{k[1:] if k.startswith("_") else k: v for k, v in object_to_dict(self, **kwargs).items()} + **{ + k[1:] if k.startswith("_") else k: v + for k, v in object_to_dict(self, **kwargs).items() + } ) def sql(self, **kwargs) -> t.List[str]: diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py index 1ea86d1..8cb16ef 100644 --- a/sqlglot/dataframe/sql/session.py +++ b/sqlglot/dataframe/sql/session.py @@ -67,13 +67,20 @@ class SparkSession: data_expressions = [ exp.Tuple( - expressions=list(map(lambda x: F.lit(x).expression, row if not isinstance(row, dict) else row.values())) + expressions=list( + map( + lambda x: F.lit(x).expression, + row if not isinstance(row, dict) else row.values(), + ) + ) ) for row in data ] sel_columns = [ - F.col(name).cast(data_type).alias(name).expression if data_type is not None else F.col(name).expression + F.col(name).cast(data_type).alias(name).expression + if data_type is not None + else F.col(name).expression for name, data_type in column_mapping.items() ] @@ -106,10 +113,12 @@ class SparkSession: select_expression.set("with", expression.args.get("with")) expression.set("with", None) del expression.args["expression"] - df = DataFrame(self, select_expression, output_expression_container=expression) + df = DataFrame(self, select_expression, output_expression_container=expression) # type: ignore df = df._convert_leaf_to_cte() else: - raise ValueError("Unknown expression type provided in the SQL. Please create an issue with the SQL.") + raise ValueError( + "Unknown expression type provided in the SQL. Please create an issue with the SQL." + ) return df @property diff --git a/sqlglot/dataframe/sql/types.py b/sqlglot/dataframe/sql/types.py index dc5c05a..a63e505 100644 --- a/sqlglot/dataframe/sql/types.py +++ b/sqlglot/dataframe/sql/types.py @@ -158,7 +158,11 @@ class MapType(DataType): class StructField(DataType): def __init__( - self, name: str, dataType: DataType, nullable: bool = True, metadata: t.Optional[t.Dict[str, t.Any]] = None + self, + name: str, + dataType: DataType, + nullable: bool = True, + metadata: t.Optional[t.Dict[str, t.Any]] = None, ): self.name = name self.dataType = dataType diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py index 842f366..c54c07e 100644 --- a/sqlglot/dataframe/sql/window.py +++ b/sqlglot/dataframe/sql/window.py @@ -74,8 +74,13 @@ class WindowSpec: window_spec.expression.args["order"].set("expressions", order_by) return window_spec - def _calc_start_end(self, start: int, end: int) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]: - kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {"start_side": None, "end_side": None} + def _calc_start_end( + self, start: int, end: int + ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]: + kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = { + "start_side": None, + "end_side": None, + } if start == Window.currentRow: kwargs["start"] = "CURRENT ROW" else: @@ -83,7 +88,9 @@ class WindowSpec: **kwargs, **{ "start_side": "PRECEDING", - "start": "UNBOUNDED" if start <= Window.unboundedPreceding else F.lit(start).expression, + "start": "UNBOUNDED" + if start <= Window.unboundedPreceding + else F.lit(start).expression, }, } if end == Window.currentRow: @@ -93,7 +100,9 @@ class WindowSpec: **kwargs, **{ "end_side": "FOLLOWING", - "end": "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression, + "end": "UNBOUNDED" + if end >= Window.unboundedFollowing + else F.lit(end).expression, }, } return kwargs @@ -103,7 +112,10 @@ class WindowSpec: spec = self._calc_start_end(start, end) spec["kind"] = "ROWS" window_spec.expression.set( - "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}) + "spec", + exp.WindowSpec( + **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} + ), ) return window_spec @@ -112,6 +124,9 @@ class WindowSpec: spec = self._calc_start_end(start, end) spec["kind"] = "RANGE" window_spec.expression.set( - "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}) + "spec", + exp.WindowSpec( + **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} + ), ) return window_spec diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 62d042e..5bbff9d 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -1,21 +1,21 @@ -from sqlglot import exp +from __future__ import annotations + +from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, inline_array_sql, no_ilike_sql, rename_func, ) -from sqlglot.generator import Generator -from sqlglot.helper import list_get -from sqlglot.parser import Parser -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.helper import seq_get +from sqlglot.tokens import TokenType def _date_add(expression_class): def func(args): - interval = list_get(args, 1) + interval = seq_get(args, 1) return expression_class( - this=list_get(args, 0), + this=seq_get(args, 0), expression=interval.this, unit=interval.args.get("unit"), ) @@ -23,6 +23,13 @@ def _date_add(expression_class): return func +def _date_trunc(args): + unit = seq_get(args, 1) + if isinstance(unit, exp.Column): + unit = exp.Var(this=unit.name) + return exp.DateTrunc(this=seq_get(args, 0), expression=unit) + + def _date_add_sql(data_type, kind): def func(self, expression): this = self.sql(expression, "this") @@ -40,7 +47,8 @@ def _derived_table_values_to_unnest(self, expression): structs = [] for row in rows: aliases = [ - exp.alias_(value, column_name) for value, column_name in zip(row, expression.args["alias"].args["columns"]) + exp.alias_(value, column_name) + for value, column_name in zip(row, expression.args["alias"].args["columns"]) ] structs.append(exp.Struct(expressions=aliases)) unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)]) @@ -89,18 +97,19 @@ class BigQuery(Dialect): "%j": "%-j", } - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): QUOTES = [ (prefix + quote, quote) if prefix else quote for quote in ["'", '"', '"""', "'''"] for prefix in ["", "r", "R"] ] + COMMENTS = ["--", "#", ("/*", "*/")] IDENTIFIERS = ["`"] - ESCAPE = "\\" + ESCAPES = ["\\"] HEX_STRINGS = [("0x", ""), ("0X", "")] KEYWORDS = { - **Tokenizer.KEYWORDS, + **tokens.Tokenizer.KEYWORDS, "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, "CURRENT_TIME": TokenType.CURRENT_TIME, "GEOGRAPHY": TokenType.GEOGRAPHY, @@ -111,35 +120,40 @@ class BigQuery(Dialect): "WINDOW": TokenType.WINDOW, "NOT DETERMINISTIC": TokenType.VOLATILE, } + KEYWORDS.pop("DIV") - class Parser(Parser): + class Parser(parser.Parser): FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, + "DATE_TRUNC": _date_trunc, "DATE_ADD": _date_add(exp.DateAdd), "DATETIME_ADD": _date_add(exp.DatetimeAdd), + "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)), "TIME_ADD": _date_add(exp.TimeAdd), "TIMESTAMP_ADD": _date_add(exp.TimestampAdd), "DATE_SUB": _date_add(exp.DateSub), "DATETIME_SUB": _date_add(exp.DatetimeSub), "TIME_SUB": _date_add(exp.TimeSub), "TIMESTAMP_SUB": _date_add(exp.TimestampSub), - "PARSE_TIMESTAMP": lambda args: exp.StrToTime(this=list_get(args, 1), format=list_get(args, 0)), + "PARSE_TIMESTAMP": lambda args: exp.StrToTime( + this=seq_get(args, 1), format=seq_get(args, 0) + ), } NO_PAREN_FUNCTIONS = { - **Parser.NO_PAREN_FUNCTIONS, + **parser.Parser.NO_PAREN_FUNCTIONS, TokenType.CURRENT_DATETIME: exp.CurrentDatetime, TokenType.CURRENT_TIME: exp.CurrentTime, } NESTED_TYPE_TOKENS = { - *Parser.NESTED_TYPE_TOKENS, + *parser.Parser.NESTED_TYPE_TOKENS, TokenType.TABLE, } - class Generator(Generator): + class Generator(generator.Generator): TRANSFORMS = { - **Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, exp.Array: inline_array_sql, exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.DateAdd: _date_add_sql("DATE", "ADD"), @@ -148,6 +162,7 @@ class BigQuery(Dialect): exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"), exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})", exp.ILike: no_ilike_sql, + exp.IntDiv: rename_func("DIV"), exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})", exp.TimeAdd: _date_add_sql("TIME", "ADD"), exp.TimeSub: _date_add_sql("TIME", "SUB"), @@ -157,11 +172,13 @@ class BigQuery(Dialect): exp.Values: _derived_table_values_to_unnest, exp.ReturnsProperty: _returnsproperty_sql, exp.Create: _create_sql, - exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC", + exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" + if e.name == "IMMUTABLE" + else "NOT DETERMINISTIC", } TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TINYINT: "INT64", exp.DataType.Type.SMALLINT: "INT64", exp.DataType.Type.INT: "INT64", diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index f446e6d..332b4c1 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -1,8 +1,9 @@ -from sqlglot import exp +from __future__ import annotations + +from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql -from sqlglot.generator import Generator -from sqlglot.parser import Parser, parse_var_map -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.parser import parse_var_map +from sqlglot.tokens import TokenType def _lower_func(sql): @@ -14,11 +15,12 @@ class ClickHouse(Dialect): normalize_functions = None null_ordering = "nulls_are_last" - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): + COMMENTS = ["--", "#", "#!", ("/*", "*/")] IDENTIFIERS = ['"', "`"] KEYWORDS = { - **Tokenizer.KEYWORDS, + **tokens.Tokenizer.KEYWORDS, "FINAL": TokenType.FINAL, "DATETIME64": TokenType.DATETIME, "INT8": TokenType.TINYINT, @@ -30,9 +32,9 @@ class ClickHouse(Dialect): "TUPLE": TokenType.STRUCT, } - class Parser(Parser): + class Parser(parser.Parser): FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "MAP": parse_var_map, } @@ -44,11 +46,11 @@ class ClickHouse(Dialect): return this - class Generator(Generator): + class Generator(generator.Generator): STRUCT_DELIMITER = ("(", ")") TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.NULLABLE: "Nullable", exp.DataType.Type.DATETIME: "DateTime64", exp.DataType.Type.MAP: "Map", @@ -63,7 +65,7 @@ class ClickHouse(Dialect): } TRANSFORMS = { - **Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, exp.Array: inline_array_sql, exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 9dc3c38..2498c62 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from sqlglot import exp from sqlglot.dialects.dialect import parse_date_delta from sqlglot.dialects.spark import Spark @@ -15,7 +17,7 @@ class Databricks(Spark): class Generator(Spark.Generator): TRANSFORMS = { - **Spark.Generator.TRANSFORMS, + **Spark.Generator.TRANSFORMS, # type: ignore exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, } diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 33985a7..3af08bb 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -1,8 +1,11 @@ +from __future__ import annotations + +import typing as t from enum import Enum from sqlglot import exp from sqlglot.generator import Generator -from sqlglot.helper import flatten, list_get +from sqlglot.helper import flatten, seq_get from sqlglot.parser import Parser from sqlglot.time import format_time from sqlglot.tokens import Tokenizer @@ -32,7 +35,7 @@ class Dialects(str, Enum): class _Dialect(type): - classes = {} + classes: t.Dict[str, Dialect] = {} @classmethod def __getitem__(cls, key): @@ -56,19 +59,30 @@ class _Dialect(type): klass.generator_class = getattr(klass, "Generator", Generator) klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0] - klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0] - - if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS: + klass.identifier_start, klass.identifier_end = list( + klass.tokenizer_class._IDENTIFIERS.items() + )[0] + + if ( + klass.tokenizer_class._BIT_STRINGS + and exp.BitString not in klass.generator_class.TRANSFORMS + ): bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0] klass.generator_class.TRANSFORMS[ exp.BitString ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}" - if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS: + if ( + klass.tokenizer_class._HEX_STRINGS + and exp.HexString not in klass.generator_class.TRANSFORMS + ): hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0] klass.generator_class.TRANSFORMS[ exp.HexString ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}" - if klass.tokenizer_class._BYTE_STRINGS and exp.ByteString not in klass.generator_class.TRANSFORMS: + if ( + klass.tokenizer_class._BYTE_STRINGS + and exp.ByteString not in klass.generator_class.TRANSFORMS + ): be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0] klass.generator_class.TRANSFORMS[ exp.ByteString @@ -81,13 +95,13 @@ class Dialect(metaclass=_Dialect): index_offset = 0 unnest_column_only = False alias_post_tablesample = False - normalize_functions = "upper" + normalize_functions: t.Optional[str] = "upper" null_ordering = "nulls_are_small" date_format = "'%Y-%m-%d'" dateint_format = "'%Y%m%d'" time_format = "'%Y-%m-%d %H:%M:%S'" - time_mapping = {} + time_mapping: t.Dict[str, str] = {} # autofilled quote_start = None @@ -167,7 +181,7 @@ class Dialect(metaclass=_Dialect): "quote_end": self.quote_end, "identifier_start": self.identifier_start, "identifier_end": self.identifier_end, - "escape": self.tokenizer_class.ESCAPE, + "escape": self.tokenizer_class.ESCAPES[0], "index_offset": self.index_offset, "time_mapping": self.inverse_time_mapping, "time_trie": self.inverse_time_trie, @@ -195,7 +209,9 @@ def approx_count_distinct_sql(self, expression): def if_sql(self, expression): - expressions = self.format_args(expression.this, expression.args.get("true"), expression.args.get("false")) + expressions = self.format_args( + expression.this, expression.args.get("true"), expression.args.get("false") + ) return f"IF({expressions})" @@ -298,9 +314,9 @@ def format_time_lambda(exp_class, dialect, default=None): def _format_time(args): return exp_class( - this=list_get(args, 0), + this=seq_get(args, 0), format=Dialect[dialect].format_time( - list_get(args, 1) or (Dialect[dialect].time_format if default is True else default) + seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default) ), ) @@ -328,7 +344,9 @@ def create_with_partitions_sql(self, expression): "expressions", [e for e in schema.expressions if e not in partitions], ) - prop.replace(exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))) + prop.replace( + exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions)) + ) expression.set("this", schema) return self.create_sql(expression) @@ -337,9 +355,9 @@ def create_with_partitions_sql(self, expression): def parse_date_delta(exp_class, unit_mapping=None): def inner_func(args): unit_based = len(args) == 3 - this = list_get(args, 2) if unit_based else list_get(args, 0) - expression = list_get(args, 1) if unit_based else list_get(args, 1) - unit = list_get(args, 0) if unit_based else exp.Literal.string("DAY") + this = seq_get(args, 2) if unit_based else seq_get(args, 0) + expression = seq_get(args, 1) if unit_based else seq_get(args, 1) + unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY") unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit return exp_class(this=this, expression=expression, unit=unit) diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index f3ff6d3..781edff 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -1,4 +1,6 @@ -from sqlglot import exp +from __future__ import annotations + +from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, approx_count_distinct_sql, @@ -12,10 +14,8 @@ from sqlglot.dialects.dialect import ( rename_func, str_position_sql, ) -from sqlglot.generator import Generator -from sqlglot.helper import list_get -from sqlglot.parser import Parser -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.helper import seq_get +from sqlglot.tokens import TokenType def _unix_to_time(self, expression): @@ -61,11 +61,14 @@ def _sort_array_sql(self, expression): def _sort_array_reverse(args): - return exp.SortArray(this=list_get(args, 0), asc=exp.FALSE) + return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE) def _struct_pack_sql(self, expression): - args = [self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) for e in expression.expressions] + args = [ + self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) + for e in expression.expressions + ] return f"STRUCT_PACK({', '.join(args)})" @@ -76,15 +79,15 @@ def _datatype_sql(self, expression): class DuckDB(Dialect): - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): KEYWORDS = { - **Tokenizer.KEYWORDS, + **tokens.Tokenizer.KEYWORDS, ":=": TokenType.EQ, } - class Parser(Parser): + class Parser(parser.Parser): FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, "ARRAY_LENGTH": exp.ArraySize.from_arg_list, "ARRAY_SORT": exp.SortArray.from_arg_list, @@ -92,7 +95,7 @@ class DuckDB(Dialect): "EPOCH": exp.TimeToUnix.from_arg_list, "EPOCH_MS": lambda args: exp.UnixToTime( this=exp.Div( - this=list_get(args, 0), + this=seq_get(args, 0), expression=exp.Literal.number(1000), ) ), @@ -112,11 +115,11 @@ class DuckDB(Dialect): "UNNEST": exp.Explode.from_arg_list, } - class Generator(Generator): + class Generator(generator.Generator): STRUCT_DELIMITER = ("(", ")") TRANSFORMS = { - **Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, exp.Array: rename_func("LIST_VALUE"), exp.ArraySize: rename_func("ARRAY_LENGTH"), @@ -160,7 +163,7 @@ class DuckDB(Dialect): } TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.VARCHAR: "TEXT", exp.DataType.Type.NVARCHAR: "TEXT", } diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 03049ff..ed7357c 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -1,4 +1,6 @@ -from sqlglot import exp, transforms +from __future__ import annotations + +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, approx_count_distinct_sql, @@ -13,10 +15,8 @@ from sqlglot.dialects.dialect import ( struct_extract_sql, var_map_sql, ) -from sqlglot.generator import Generator -from sqlglot.helper import list_get -from sqlglot.parser import Parser, parse_var_map -from sqlglot.tokens import Tokenizer +from sqlglot.helper import seq_get +from sqlglot.parser import parse_var_map # (FuncType, Multiplier) DATE_DELTA_INTERVAL = { @@ -34,7 +34,9 @@ def _add_date_sql(self, expression): unit = expression.text("unit").upper() func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) modified_increment = ( - int(expression.text("expression")) * multiplier if expression.expression.is_number else expression.expression + int(expression.text("expression")) * multiplier + if expression.expression.is_number + else expression.expression ) modified_increment = exp.Literal.number(modified_increment) return f"{func}({self.format_args(expression.this, modified_increment.this)})" @@ -165,10 +167,10 @@ class Hive(Dialect): dateint_format = "'yyyyMMdd'" time_format = "'yyyy-MM-dd HH:mm:ss'" - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): QUOTES = ["'", '"'] IDENTIFIERS = ["`"] - ESCAPE = "\\" + ESCAPES = ["\\"] ENCODE = "utf-8" NUMERIC_LITERALS = { @@ -180,40 +182,44 @@ class Hive(Dialect): "BD": "DECIMAL", } - class Parser(Parser): + class Parser(parser.Parser): STRICT_CAST = False FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, "COLLECT_LIST": exp.ArrayAgg.from_arg_list, "DATE_ADD": lambda args: exp.TsOrDsAdd( - this=list_get(args, 0), - expression=list_get(args, 1), + this=seq_get(args, 0), + expression=seq_get(args, 1), unit=exp.Literal.string("DAY"), ), "DATEDIFF": lambda args: exp.DateDiff( - this=exp.TsOrDsToDate(this=list_get(args, 0)), - expression=exp.TsOrDsToDate(this=list_get(args, 1)), + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + expression=exp.TsOrDsToDate(this=seq_get(args, 1)), ), "DATE_SUB": lambda args: exp.TsOrDsAdd( - this=list_get(args, 0), + this=seq_get(args, 0), expression=exp.Mul( - this=list_get(args, 1), + this=seq_get(args, 1), expression=exp.Literal.number(-1), ), unit=exp.Literal.string("DAY"), ), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"), - "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=list_get(args, 0))), + "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list, "LOCATE": lambda args: exp.StrPosition( - this=list_get(args, 1), - substr=list_get(args, 0), - position=list_get(args, 2), + this=seq_get(args, 1), + substr=seq_get(args, 0), + position=seq_get(args, 2), + ), + "LOG": ( + lambda args: exp.Log.from_arg_list(args) + if len(args) > 1 + else exp.Ln.from_arg_list(args) ), - "LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)), "MAP": parse_var_map, "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), "PERCENTILE": exp.Quantile.from_arg_list, @@ -226,15 +232,16 @@ class Hive(Dialect): "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)), } - class Generator(Generator): + class Generator(generator.Generator): TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.VARBINARY: "BINARY", } TRANSFORMS = { - **Generator.TRANSFORMS, - **transforms.UNALIAS_GROUP, + **generator.Generator.TRANSFORMS, + **transforms.UNALIAS_GROUP, # type: ignore exp.AnonymousProperty: _property_sql, exp.ApproxDistinct: approx_count_distinct_sql, exp.ArrayAgg: rename_func("COLLECT_LIST"), diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 524390f..e742640 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -1,4 +1,8 @@ -from sqlglot import exp +from __future__ import annotations + +import typing as t + +from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, no_ilike_sql, @@ -6,42 +10,47 @@ from sqlglot.dialects.dialect import ( no_tablesample_sql, no_trycast_sql, ) -from sqlglot.generator import Generator -from sqlglot.helper import list_get -from sqlglot.parser import Parser -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.helper import seq_get +from sqlglot.tokens import TokenType + + +def _show_parser(*args, **kwargs): + def _parse(self): + return self._parse_show_mysql(*args, **kwargs) + + return _parse def _date_trunc_sql(self, expression): - unit = expression.text("unit").lower() + unit = expression.name.lower() - this = self.sql(expression.this) + expr = self.sql(expression.expression) if unit == "day": - return f"DATE({this})" + return f"DATE({expr})" if unit == "week": - concat = f"CONCAT(YEAR({this}), ' ', WEEK({this}, 1), ' 1')" + concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')" date_format = "%Y %u %w" elif unit == "month": - concat = f"CONCAT(YEAR({this}), ' ', MONTH({this}), ' 1')" + concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')" date_format = "%Y %c %e" elif unit == "quarter": - concat = f"CONCAT(YEAR({this}), ' ', QUARTER({this}) * 3 - 2, ' 1')" + concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')" date_format = "%Y %c %e" elif unit == "year": - concat = f"CONCAT(YEAR({this}), ' 1 1')" + concat = f"CONCAT(YEAR({expr}), ' 1 1')" date_format = "%Y %c %e" else: self.unsupported("Unexpected interval unit: {unit}") - return f"DATE({this})" + return f"DATE({expr})" return f"STR_TO_DATE({concat}, '{date_format}')" def _str_to_date(args): - date_format = MySQL.format_time(list_get(args, 1)) - return exp.StrToDate(this=list_get(args, 0), format=date_format) + date_format = MySQL.format_time(seq_get(args, 1)) + return exp.StrToDate(this=seq_get(args, 0), format=date_format) def _str_to_date_sql(self, expression): @@ -66,9 +75,9 @@ def _trim_sql(self, expression): def _date_add(expression_class): def func(args): - interval = list_get(args, 1) + interval = seq_get(args, 1) return expression_class( - this=list_get(args, 0), + this=seq_get(args, 0), expression=interval.this, unit=exp.Literal.string(interval.text("unit").lower()), ) @@ -101,15 +110,16 @@ class MySQL(Dialect): "%l": "%-I", } - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): QUOTES = ["'", '"'] COMMENTS = ["--", "#", ("/*", "*/")] IDENTIFIERS = ["`"] + ESCAPES = ["'", "\\"] BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")] HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")] KEYWORDS = { - **Tokenizer.KEYWORDS, + **tokens.Tokenizer.KEYWORDS, "SEPARATOR": TokenType.SEPARATOR, "_ARMSCII8": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER, @@ -156,20 +166,23 @@ class MySQL(Dialect): "_UTF32": TokenType.INTRODUCER, "_UTF8MB3": TokenType.INTRODUCER, "_UTF8MB4": TokenType.INTRODUCER, + "@@": TokenType.SESSION_PARAMETER, } - class Parser(Parser): + COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW} + + class Parser(parser.Parser): STRICT_CAST = False FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "DATE_ADD": _date_add(exp.DateAdd), "DATE_SUB": _date_add(exp.DateSub), "STR_TO_DATE": _str_to_date, } FUNCTION_PARSERS = { - **Parser.FUNCTION_PARSERS, + **parser.Parser.FUNCTION_PARSERS, "GROUP_CONCAT": lambda self: self.expression( exp.GroupConcat, this=self._parse_lambda(), @@ -178,15 +191,212 @@ class MySQL(Dialect): } PROPERTY_PARSERS = { - **Parser.PROPERTY_PARSERS, + **parser.Parser.PROPERTY_PARSERS, TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty), } - class Generator(Generator): + STATEMENT_PARSERS = { + **parser.Parser.STATEMENT_PARSERS, + TokenType.SHOW: lambda self: self._parse_show(), + TokenType.SET: lambda self: self._parse_set(), + } + + SHOW_PARSERS = { + "BINARY LOGS": _show_parser("BINARY LOGS"), + "MASTER LOGS": _show_parser("BINARY LOGS"), + "BINLOG EVENTS": _show_parser("BINLOG EVENTS"), + "CHARACTER SET": _show_parser("CHARACTER SET"), + "CHARSET": _show_parser("CHARACTER SET"), + "COLLATION": _show_parser("COLLATION"), + "FULL COLUMNS": _show_parser("COLUMNS", target="FROM", full=True), + "COLUMNS": _show_parser("COLUMNS", target="FROM"), + "CREATE DATABASE": _show_parser("CREATE DATABASE", target=True), + "CREATE EVENT": _show_parser("CREATE EVENT", target=True), + "CREATE FUNCTION": _show_parser("CREATE FUNCTION", target=True), + "CREATE PROCEDURE": _show_parser("CREATE PROCEDURE", target=True), + "CREATE TABLE": _show_parser("CREATE TABLE", target=True), + "CREATE TRIGGER": _show_parser("CREATE TRIGGER", target=True), + "CREATE VIEW": _show_parser("CREATE VIEW", target=True), + "DATABASES": _show_parser("DATABASES"), + "ENGINE": _show_parser("ENGINE", target=True), + "STORAGE ENGINES": _show_parser("ENGINES"), + "ENGINES": _show_parser("ENGINES"), + "ERRORS": _show_parser("ERRORS"), + "EVENTS": _show_parser("EVENTS"), + "FUNCTION CODE": _show_parser("FUNCTION CODE", target=True), + "FUNCTION STATUS": _show_parser("FUNCTION STATUS"), + "GRANTS": _show_parser("GRANTS", target="FOR"), + "INDEX": _show_parser("INDEX", target="FROM"), + "MASTER STATUS": _show_parser("MASTER STATUS"), + "OPEN TABLES": _show_parser("OPEN TABLES"), + "PLUGINS": _show_parser("PLUGINS"), + "PROCEDURE CODE": _show_parser("PROCEDURE CODE", target=True), + "PROCEDURE STATUS": _show_parser("PROCEDURE STATUS"), + "PRIVILEGES": _show_parser("PRIVILEGES"), + "FULL PROCESSLIST": _show_parser("PROCESSLIST", full=True), + "PROCESSLIST": _show_parser("PROCESSLIST"), + "PROFILE": _show_parser("PROFILE"), + "PROFILES": _show_parser("PROFILES"), + "RELAYLOG EVENTS": _show_parser("RELAYLOG EVENTS"), + "REPLICAS": _show_parser("REPLICAS"), + "SLAVE HOSTS": _show_parser("REPLICAS"), + "REPLICA STATUS": _show_parser("REPLICA STATUS"), + "SLAVE STATUS": _show_parser("REPLICA STATUS"), + "GLOBAL STATUS": _show_parser("STATUS", global_=True), + "SESSION STATUS": _show_parser("STATUS"), + "STATUS": _show_parser("STATUS"), + "TABLE STATUS": _show_parser("TABLE STATUS"), + "FULL TABLES": _show_parser("TABLES", full=True), + "TABLES": _show_parser("TABLES"), + "TRIGGERS": _show_parser("TRIGGERS"), + "GLOBAL VARIABLES": _show_parser("VARIABLES", global_=True), + "SESSION VARIABLES": _show_parser("VARIABLES"), + "VARIABLES": _show_parser("VARIABLES"), + "WARNINGS": _show_parser("WARNINGS"), + } + + SET_PARSERS = { + "GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"), + "PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"), + "PERSIST_ONLY": lambda self: self._parse_set_item_assignment("PERSIST_ONLY"), + "SESSION": lambda self: self._parse_set_item_assignment("SESSION"), + "LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"), + "CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"), + "CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"), + "NAMES": lambda self: self._parse_set_item_names(), + } + + PROFILE_TYPES = { + "ALL", + "BLOCK IO", + "CONTEXT SWITCHES", + "CPU", + "IPC", + "MEMORY", + "PAGE FAULTS", + "SOURCE", + "SWAPS", + } + + def _parse_show_mysql(self, this, target=False, full=None, global_=None): + if target: + if isinstance(target, str): + self._match_text(target) + target_id = self._parse_id_var() + else: + target_id = None + + log = self._parse_string() if self._match_text("IN") else None + + if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}: + position = self._parse_number() if self._match_text("FROM") else None + db = None + else: + position = None + db = self._parse_id_var() if self._match_text("FROM") else None + + channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None + + like = self._parse_string() if self._match_text("LIKE") else None + where = self._parse_where() + + if this == "PROFILE": + types = self._parse_csv(self._parse_show_profile_type) + query = self._parse_number() if self._match_text("FOR", "QUERY") else None + offset = self._parse_number() if self._match_text("OFFSET") else None + limit = self._parse_number() if self._match_text("LIMIT") else None + else: + types, query = None, None + offset, limit = self._parse_oldstyle_limit() + + mutex = True if self._match_text("MUTEX") else None + mutex = False if self._match_text("STATUS") else mutex + + return self.expression( + exp.Show, + this=this, + target=target_id, + full=full, + log=log, + position=position, + db=db, + channel=channel, + like=like, + where=where, + types=types, + query=query, + offset=offset, + limit=limit, + mutex=mutex, + **{"global": global_}, + ) + + def _parse_show_profile_type(self): + for type_ in self.PROFILE_TYPES: + if self._match_text(*type_.split(" ")): + return exp.Var(this=type_) + return None + + def _parse_oldstyle_limit(self): + limit = None + offset = None + if self._match_text("LIMIT"): + parts = self._parse_csv(self._parse_number) + if len(parts) == 1: + limit = parts[0] + elif len(parts) == 2: + limit = parts[1] + offset = parts[0] + return offset, limit + + def _default_parse_set_item(self): + return self._parse_set_item_assignment(kind=None) + + def _parse_set_item_assignment(self, kind): + left = self._parse_primary() or self._parse_id_var() + if not self._match(TokenType.EQ): + self.raise_error("Expected =") + right = self._parse_statement() or self._parse_id_var() + + this = self.expression( + exp.EQ, + this=left, + expression=right, + ) + + return self.expression( + exp.SetItem, + this=this, + kind=kind, + ) + + def _parse_set_item_charset(self, kind): + this = self._parse_string() or self._parse_id_var() + + return self.expression( + exp.SetItem, + this=this, + kind=kind, + ) + + def _parse_set_item_names(self): + charset = self._parse_string() or self._parse_id_var() + if self._match_text("COLLATE"): + collate = self._parse_string() or self._parse_id_var() + else: + collate = None + return self.expression( + exp.SetItem, + this=charset, + collate=collate, + kind="NAMES", + ) + + class Generator(generator.Generator): NULL_ORDERING_SUPPORTED = False TRANSFORMS = { - **Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, exp.CurrentDate: no_paren_current_date_sql, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.ILike: no_ilike_sql, @@ -199,6 +409,8 @@ class MySQL(Dialect): exp.StrToDate: _str_to_date_sql, exp.StrToTime: _str_to_date_sql, exp.Trim: _trim_sql, + exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), + exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), } ROOT_PROPERTIES = { @@ -209,4 +421,69 @@ class MySQL(Dialect): exp.SchemaCommentProperty, } - WITH_PROPERTIES = {} + WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set() + + def show_sql(self, expression): + this = f" {expression.name}" + full = " FULL" if expression.args.get("full") else "" + global_ = " GLOBAL" if expression.args.get("global") else "" + + target = self.sql(expression, "target") + target = f" {target}" if target else "" + if expression.name in {"COLUMNS", "INDEX"}: + target = f" FROM{target}" + elif expression.name == "GRANTS": + target = f" FOR{target}" + + db = self._prefixed_sql("FROM", expression, "db") + + like = self._prefixed_sql("LIKE", expression, "like") + where = self.sql(expression, "where") + + types = self.expressions(expression, key="types") + types = f" {types}" if types else types + query = self._prefixed_sql("FOR QUERY", expression, "query") + + if expression.name == "PROFILE": + offset = self._prefixed_sql("OFFSET", expression, "offset") + limit = self._prefixed_sql("LIMIT", expression, "limit") + else: + offset = "" + limit = self._oldstyle_limit_sql(expression) + + log = self._prefixed_sql("IN", expression, "log") + position = self._prefixed_sql("FROM", expression, "position") + + channel = self._prefixed_sql("FOR CHANNEL", expression, "channel") + + if expression.name == "ENGINE": + mutex_or_status = " MUTEX" if expression.args.get("mutex") else " STATUS" + else: + mutex_or_status = "" + + return f"SHOW{full}{global_}{this}{target}{types}{db}{query}{log}{position}{channel}{mutex_or_status}{like}{where}{offset}{limit}" + + def _prefixed_sql(self, prefix, expression, arg): + sql = self.sql(expression, arg) + if not sql: + return "" + return f" {prefix} {sql}" + + def _oldstyle_limit_sql(self, expression): + limit = self.sql(expression, "limit") + offset = self.sql(expression, "offset") + if limit: + limit_offset = f"{offset}, {limit}" if offset else limit + return f" LIMIT {limit_offset}" + return "" + + def setitem_sql(self, expression): + kind = self.sql(expression, "kind") + kind = f"{kind} " if kind else "" + this = self.sql(expression, "this") + collate = self.sql(expression, "collate") + collate = f" COLLATE {collate}" if collate else "" + return f"{kind}{this}{collate}" + + def set_sql(self, expression): + return f"SET {self.expressions(expression)}" diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 144dba5..3bc1109 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -1,8 +1,9 @@ -from sqlglot import exp, transforms +from __future__ import annotations + +from sqlglot import exp, generator, tokens, transforms from sqlglot.dialects.dialect import Dialect, no_ilike_sql -from sqlglot.generator import Generator from sqlglot.helper import csv -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.tokens import TokenType def _limit_sql(self, expression): @@ -36,9 +37,9 @@ class Oracle(Dialect): "YYYY": "%Y", # 2015 } - class Generator(Generator): + class Generator(generator.Generator): TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TINYINT: "NUMBER", exp.DataType.Type.SMALLINT: "NUMBER", exp.DataType.Type.INT: "NUMBER", @@ -49,11 +50,12 @@ class Oracle(Dialect): exp.DataType.Type.NVARCHAR: "NVARCHAR2", exp.DataType.Type.TEXT: "CLOB", exp.DataType.Type.BINARY: "BLOB", + exp.DataType.Type.VARBINARY: "BLOB", } TRANSFORMS = { - **Generator.TRANSFORMS, - **transforms.UNALIAS_GROUP, + **generator.Generator.TRANSFORMS, + **transforms.UNALIAS_GROUP, # type: ignore exp.ILike: no_ilike_sql, exp.Limit: _limit_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", @@ -86,9 +88,9 @@ class Oracle(Dialect): def table_sql(self, expression): return super().table_sql(expression, sep=" ") - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): KEYWORDS = { - **Tokenizer.KEYWORDS, + **tokens.Tokenizer.KEYWORDS, "TOP": TokenType.TOP, "VARCHAR2": TokenType.VARCHAR, "NVARCHAR2": TokenType.NVARCHAR, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 459e926..553a73b 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -1,4 +1,6 @@ -from sqlglot import exp +from __future__ import annotations + +from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, arrow_json_extract_scalar_sql, @@ -9,9 +11,7 @@ from sqlglot.dialects.dialect import ( no_trycast_sql, str_position_sql, ) -from sqlglot.generator import Generator -from sqlglot.parser import Parser -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.tokens import TokenType from sqlglot.transforms import delegate, preprocess @@ -160,12 +160,12 @@ class Postgres(Dialect): "YYYY": "%Y", # 2015 } - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): BIT_STRINGS = [("b'", "'"), ("B'", "'")] HEX_STRINGS = [("x'", "'"), ("X'", "'")] BYTE_STRINGS = [("e'", "'"), ("E'", "'")] KEYWORDS = { - **Tokenizer.KEYWORDS, + **tokens.Tokenizer.KEYWORDS, "ALWAYS": TokenType.ALWAYS, "BY DEFAULT": TokenType.BY_DEFAULT, "COMMENT ON": TokenType.COMMENT_ON, @@ -179,31 +179,32 @@ class Postgres(Dialect): } QUOTES = ["'", "$$"] SINGLE_TOKENS = { - **Tokenizer.SINGLE_TOKENS, + **tokens.Tokenizer.SINGLE_TOKENS, "$": TokenType.PARAMETER, } - class Parser(Parser): + class Parser(parser.Parser): STRICT_CAST = False FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "TO_TIMESTAMP": _to_timestamp, "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), } - class Generator(Generator): + class Generator(generator.Generator): TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TINYINT: "SMALLINT", exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", exp.DataType.Type.BINARY: "BYTEA", + exp.DataType.Type.VARBINARY: "BYTEA", exp.DataType.Type.DATETIME: "TIMESTAMP", } TRANSFORMS = { - **Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, exp.ColumnDef: preprocess( [ _auto_increment_to_serial, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index a2d392c..11ea778 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -1,4 +1,6 @@ -from sqlglot import exp, transforms +from __future__ import annotations + +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, format_time_lambda, @@ -10,10 +12,8 @@ from sqlglot.dialects.dialect import ( struct_extract_sql, ) from sqlglot.dialects.mysql import MySQL -from sqlglot.generator import Generator -from sqlglot.helper import list_get -from sqlglot.parser import Parser -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.helper import seq_get +from sqlglot.tokens import TokenType def _approx_distinct_sql(self, expression): @@ -110,30 +110,29 @@ class Presto(Dialect): index_offset = 1 null_ordering = "nulls_are_last" time_format = "'%Y-%m-%d %H:%i:%S'" - time_mapping = MySQL.time_mapping + time_mapping = MySQL.time_mapping # type: ignore - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): KEYWORDS = { - **Tokenizer.KEYWORDS, - "VARBINARY": TokenType.BINARY, + **tokens.Tokenizer.KEYWORDS, "ROW": TokenType.STRUCT, } - class Parser(Parser): + class Parser(parser.Parser): FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, "CARDINALITY": exp.ArraySize.from_arg_list, "CONTAINS": exp.ArrayContains.from_arg_list, "DATE_ADD": lambda args: exp.DateAdd( - this=list_get(args, 2), - expression=list_get(args, 1), - unit=list_get(args, 0), + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=seq_get(args, 0), ), "DATE_DIFF": lambda args: exp.DateDiff( - this=list_get(args, 2), - expression=list_get(args, 1), - unit=list_get(args, 0), + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=seq_get(args, 0), ), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), @@ -143,7 +142,7 @@ class Presto(Dialect): "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, } - class Generator(Generator): + class Generator(generator.Generator): STRUCT_DELIMITER = ("(", ")") @@ -159,7 +158,7 @@ class Presto(Dialect): } TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.BINARY: "VARBINARY", @@ -169,8 +168,8 @@ class Presto(Dialect): } TRANSFORMS = { - **Generator.TRANSFORMS, - **transforms.UNALIAS_GROUP, + **generator.Generator.TRANSFORMS, + **transforms.UNALIAS_GROUP, # type: ignore exp.ApproxDistinct: _approx_distinct_sql, exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", exp.ArrayConcat: rename_func("CONCAT"), diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index e1f7b78..a9b12fb 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from sqlglot import exp from sqlglot.dialects.postgres import Postgres from sqlglot.tokens import TokenType @@ -6,29 +8,30 @@ from sqlglot.tokens import TokenType class Redshift(Postgres): time_format = "'YYYY-MM-DD HH:MI:SS'" time_mapping = { - **Postgres.time_mapping, + **Postgres.time_mapping, # type: ignore "MON": "%b", "HH": "%H", } class Tokenizer(Postgres.Tokenizer): - ESCAPE = "\\" + ESCAPES = ["\\"] KEYWORDS = { - **Postgres.Tokenizer.KEYWORDS, + **Postgres.Tokenizer.KEYWORDS, # type: ignore "GEOMETRY": TokenType.GEOMETRY, "GEOGRAPHY": TokenType.GEOGRAPHY, "HLLSKETCH": TokenType.HLLSKETCH, "SUPER": TokenType.SUPER, "TIME": TokenType.TIMESTAMP, "TIMETZ": TokenType.TIMESTAMPTZ, - "VARBYTE": TokenType.BINARY, + "VARBYTE": TokenType.VARBINARY, "SIMILAR TO": TokenType.SIMILAR_TO, } class Generator(Postgres.Generator): TYPE_MAPPING = { - **Postgres.Generator.TYPE_MAPPING, + **Postgres.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.BINARY: "VARBYTE", + exp.DataType.Type.VARBINARY: "VARBYTE", exp.DataType.Type.INT: "INTEGER", } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 3b97e6d..d1aaded 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -1,4 +1,6 @@ -from sqlglot import exp +from __future__ import annotations + +from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, format_time_lambda, @@ -6,10 +8,8 @@ from sqlglot.dialects.dialect import ( rename_func, ) from sqlglot.expressions import Literal -from sqlglot.generator import Generator -from sqlglot.helper import list_get -from sqlglot.parser import Parser -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.helper import seq_get +from sqlglot.tokens import TokenType def _check_int(s): @@ -28,7 +28,9 @@ def _snowflake_to_timestamp(args): # case: <numeric_expr> [ , <scale> ] if second_arg.name not in ["0", "3", "9"]: - raise ValueError(f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9") + raise ValueError( + f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9" + ) if second_arg.name == "0": timescale = exp.UnixToTime.SECONDS @@ -39,7 +41,7 @@ def _snowflake_to_timestamp(args): return exp.UnixToTime(this=first_arg, scale=timescale) - first_arg = list_get(args, 0) + first_arg = seq_get(args, 0) if not isinstance(first_arg, Literal): # case: <variant_expr> return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args) @@ -56,7 +58,7 @@ def _snowflake_to_timestamp(args): return exp.UnixToTime.from_arg_list(args) -def _unix_to_time(self, expression): +def _unix_to_time_sql(self, expression): scale = expression.args.get("scale") timestamp = self.sql(expression, "this") if scale in [None, exp.UnixToTime.SECONDS]: @@ -132,9 +134,9 @@ class Snowflake(Dialect): "ff6": "%f", } - class Parser(Parser): + class Parser(parser.Parser): FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, "IFF": exp.If.from_arg_list, "TO_TIMESTAMP": _snowflake_to_timestamp, @@ -143,18 +145,18 @@ class Snowflake(Dialect): } FUNCTION_PARSERS = { - **Parser.FUNCTION_PARSERS, + **parser.Parser.FUNCTION_PARSERS, "DATE_PART": _parse_date_part, } FUNC_TOKENS = { - *Parser.FUNC_TOKENS, + *parser.Parser.FUNC_TOKENS, TokenType.RLIKE, TokenType.TABLE, } COLUMN_OPERATORS = { - **Parser.COLUMN_OPERATORS, + **parser.Parser.COLUMN_OPERATORS, # type: ignore TokenType.COLON: lambda self, this, path: self.expression( exp.Bracket, this=this, @@ -163,21 +165,21 @@ class Snowflake(Dialect): } PROPERTY_PARSERS = { - **Parser.PROPERTY_PARSERS, + **parser.Parser.PROPERTY_PARSERS, TokenType.PARTITION_BY: lambda self: self._parse_partitioned_by(), } - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): QUOTES = ["'", "$$"] - ESCAPE = "\\" + ESCAPES = ["\\"] SINGLE_TOKENS = { - **Tokenizer.SINGLE_TOKENS, + **tokens.Tokenizer.SINGLE_TOKENS, "$": TokenType.PARAMETER, } KEYWORDS = { - **Tokenizer.KEYWORDS, + **tokens.Tokenizer.KEYWORDS, "QUALIFY": TokenType.QUALIFY, "DOUBLE PRECISION": TokenType.DOUBLE, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, @@ -187,15 +189,15 @@ class Snowflake(Dialect): "SAMPLE": TokenType.TABLE_SAMPLE, } - class Generator(Generator): + class Generator(generator.Generator): CREATE_TRANSIENT = True TRANSFORMS = { - **Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.If: rename_func("IFF"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", - exp.UnixToTime: _unix_to_time, + exp.UnixToTime: _unix_to_time_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.Array: inline_array_sql, exp.StrPosition: rename_func("POSITION"), @@ -204,7 +206,7 @@ class Snowflake(Dialect): } TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", } diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 572f411..4e404b8 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -1,8 +1,9 @@ -from sqlglot import exp +from __future__ import annotations + +from sqlglot import exp, parser from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func from sqlglot.dialects.hive import Hive -from sqlglot.helper import list_get -from sqlglot.parser import Parser +from sqlglot.helper import seq_get def _create_sql(self, e): @@ -46,36 +47,36 @@ def _unix_to_time(self, expression): class Spark(Hive): class Parser(Hive.Parser): FUNCTIONS = { - **Hive.Parser.FUNCTIONS, + **Hive.Parser.FUNCTIONS, # type: ignore "MAP_FROM_ARRAYS": exp.Map.from_arg_list, "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, "LEFT": lambda args: exp.Substring( - this=list_get(args, 0), + this=seq_get(args, 0), start=exp.Literal.number(1), - length=list_get(args, 1), + length=seq_get(args, 1), ), "SHIFTLEFT": lambda args: exp.BitwiseLeftShift( - this=list_get(args, 0), - expression=list_get(args, 1), + this=seq_get(args, 0), + expression=seq_get(args, 1), ), "SHIFTRIGHT": lambda args: exp.BitwiseRightShift( - this=list_get(args, 0), - expression=list_get(args, 1), + this=seq_get(args, 0), + expression=seq_get(args, 1), ), "RIGHT": lambda args: exp.Substring( - this=list_get(args, 0), + this=seq_get(args, 0), start=exp.Sub( - this=exp.Length(this=list_get(args, 0)), - expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)), + this=exp.Length(this=seq_get(args, 0)), + expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)), ), - length=list_get(args, 1), + length=seq_get(args, 1), ), "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, "IIF": exp.If.from_arg_list, } FUNCTION_PARSERS = { - **Parser.FUNCTION_PARSERS, + **parser.Parser.FUNCTION_PARSERS, "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"), "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"), "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"), @@ -88,14 +89,14 @@ class Spark(Hive): class Generator(Hive.Generator): TYPE_MAPPING = { - **Hive.Generator.TYPE_MAPPING, + **Hive.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TINYINT: "BYTE", exp.DataType.Type.SMALLINT: "SHORT", exp.DataType.Type.BIGINT: "LONG", } TRANSFORMS = { - **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort, exp.ILike}}, + **Hive.Generator.TRANSFORMS, # type: ignore exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}", exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", @@ -114,6 +115,8 @@ class Spark(Hive): exp.VariancePop: rename_func("VAR_POP"), exp.DateFromParts: rename_func("MAKE_DATE"), } + TRANSFORMS.pop(exp.ArraySort) + TRANSFORMS.pop(exp.ILike) WRAP_DERIVED_VALUES = False diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 62b7617..8c9fb76 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -1,4 +1,6 @@ -from sqlglot import exp +from __future__ import annotations + +from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, arrow_json_extract_scalar_sql, @@ -8,31 +10,28 @@ from sqlglot.dialects.dialect import ( no_trycast_sql, rename_func, ) -from sqlglot.generator import Generator -from sqlglot.parser import Parser -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.tokens import TokenType class SQLite(Dialect): - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]"), "`"] HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")] KEYWORDS = { - **Tokenizer.KEYWORDS, - "VARBINARY": TokenType.BINARY, + **tokens.Tokenizer.KEYWORDS, "AUTOINCREMENT": TokenType.AUTO_INCREMENT, } - class Parser(Parser): + class Parser(parser.Parser): FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "EDITDIST3": exp.Levenshtein.from_arg_list, } - class Generator(Generator): + class Generator(generator.Generator): TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.BOOLEAN: "INTEGER", exp.DataType.Type.TINYINT: "INTEGER", exp.DataType.Type.SMALLINT: "INTEGER", @@ -46,6 +45,7 @@ class SQLite(Dialect): exp.DataType.Type.VARCHAR: "TEXT", exp.DataType.Type.NVARCHAR: "TEXT", exp.DataType.Type.BINARY: "BLOB", + exp.DataType.Type.VARBINARY: "BLOB", } TOKEN_MAPPING = { @@ -53,7 +53,7 @@ class SQLite(Dialect): } TRANSFORMS = { - **Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, exp.ILike: no_ilike_sql, exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index 0cba6fe..3519c09 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -1,10 +1,12 @@ +from __future__ import annotations + from sqlglot import exp from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func from sqlglot.dialects.mysql import MySQL class StarRocks(MySQL): - class Generator(MySQL.Generator): + class Generator(MySQL.Generator): # type: ignore TYPE_MAPPING = { **MySQL.Generator.TYPE_MAPPING, exp.DataType.Type.TEXT: "STRING", @@ -13,7 +15,7 @@ class StarRocks(MySQL): } TRANSFORMS = { - **MySQL.Generator.TRANSFORMS, + **MySQL.Generator.TRANSFORMS, # type: ignore exp.JSONExtractScalar: arrow_json_extract_sql, exp.JSONExtract: arrow_json_extract_sql, exp.DateDiff: rename_func("DATEDIFF"), @@ -22,3 +24,4 @@ class StarRocks(MySQL): exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: rename_func("FROM_UNIXTIME"), } + TRANSFORMS.pop(exp.DateTrunc) diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index 45aa041..63e7275 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -1,7 +1,7 @@ -from sqlglot import exp +from __future__ import annotations + +from sqlglot import exp, generator, parser from sqlglot.dialects.dialect import Dialect -from sqlglot.generator import Generator -from sqlglot.parser import Parser def _if_sql(self, expression): @@ -20,17 +20,17 @@ def _count_sql(self, expression): class Tableau(Dialect): - class Generator(Generator): + class Generator(generator.Generator): TRANSFORMS = { - **Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.If: _if_sql, exp.Coalesce: _coalesce_sql, exp.Count: _count_sql, } - class Parser(Parser): + class Parser(parser.Parser): FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "IFNULL": exp.Coalesce.from_arg_list, "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)), } diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py index 9a6f7fe..c7b34fe 100644 --- a/sqlglot/dialects/trino.py +++ b/sqlglot/dialects/trino.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from sqlglot import exp from sqlglot.dialects.presto import Presto @@ -5,7 +7,7 @@ from sqlglot.dialects.presto import Presto class Trino(Presto): class Generator(Presto.Generator): TRANSFORMS = { - **Presto.Generator.TRANSFORMS, + **Presto.Generator.TRANSFORMS, # type: ignore exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", } diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 0f93c75..a233d4b 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -1,15 +1,22 @@ +from __future__ import annotations + import re -from sqlglot import exp +from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func from sqlglot.expressions import DataType -from sqlglot.generator import Generator -from sqlglot.helper import list_get -from sqlglot.parser import Parser +from sqlglot.helper import seq_get from sqlglot.time import format_time -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.tokens import TokenType -FULL_FORMAT_TIME_MAPPING = {"weekday": "%A", "dw": "%A", "w": "%A", "month": "%B", "mm": "%B", "m": "%B"} +FULL_FORMAT_TIME_MAPPING = { + "weekday": "%A", + "dw": "%A", + "w": "%A", + "month": "%B", + "mm": "%B", + "m": "%B", +} DATE_DELTA_INTERVAL = { "year": "year", "yyyy": "year", @@ -37,11 +44,13 @@ TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"} def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None): def _format_time(args): return exp_class( - this=list_get(args, 1), + this=seq_get(args, 1), format=exp.Literal.string( format_time( - list_get(args, 0).name or (TSQL.time_format if default is True else default), - {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping else TSQL.time_mapping, + seq_get(args, 0).name or (TSQL.time_format if default is True else default), + {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} + if full_format_mapping + else TSQL.time_mapping, ) ), ) @@ -50,12 +59,12 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None): def parse_format(args): - fmt = list_get(args, 1) + fmt = seq_get(args, 1) number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this) if number_fmt: - return exp.NumberToStr(this=list_get(args, 0), format=fmt) + return exp.NumberToStr(this=seq_get(args, 0), format=fmt) return exp.TimeToStr( - this=list_get(args, 0), + this=seq_get(args, 0), format=exp.Literal.string( format_time(fmt.name, TSQL.format_time_mapping) if len(fmt.name) == 1 @@ -188,11 +197,11 @@ class TSQL(Dialect): "Y": "%a %Y", } - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]")] KEYWORDS = { - **Tokenizer.KEYWORDS, + **tokens.Tokenizer.KEYWORDS, "BIT": TokenType.BOOLEAN, "REAL": TokenType.FLOAT, "NTEXT": TokenType.TEXT, @@ -200,7 +209,6 @@ class TSQL(Dialect): "DATETIME2": TokenType.DATETIME, "DATETIMEOFFSET": TokenType.TIMESTAMPTZ, "TIME": TokenType.TIMESTAMP, - "VARBINARY": TokenType.BINARY, "IMAGE": TokenType.IMAGE, "MONEY": TokenType.MONEY, "SMALLMONEY": TokenType.SMALLMONEY, @@ -213,9 +221,9 @@ class TSQL(Dialect): "TOP": TokenType.TOP, } - class Parser(Parser): + class Parser(parser.Parser): FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "CHARINDEX": exp.StrPosition.from_arg_list, "ISNULL": exp.Coalesce.from_arg_list, "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), @@ -243,14 +251,16 @@ class TSQL(Dialect): this = self._parse_column() # Retrieve length of datatype and override to default if not specified - if list_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES: + if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES: to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False) # Check whether a conversion with format is applicable if self._match(TokenType.COMMA): format_val = self._parse_number().name if format_val not in TSQL.convert_format_mapping: - raise ValueError(f"CONVERT function at T-SQL does not support format style {format_val}") + raise ValueError( + f"CONVERT function at T-SQL does not support format style {format_val}" + ) format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val]) # Check whether the convert entails a string to date format @@ -272,9 +282,9 @@ class TSQL(Dialect): # Entails a simple cast without any format requirement return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) - class Generator(Generator): + class Generator(generator.Generator): TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.BOOLEAN: "BIT", exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.DECIMAL: "NUMERIC", @@ -283,7 +293,7 @@ class TSQL(Dialect): } TRANSFORMS = { - **Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, exp.CurrentDate: rename_func("GETDATE"), diff --git a/sqlglot/diff.py b/sqlglot/diff.py index 0567c12..2d959ab 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -4,7 +4,7 @@ from heapq import heappop, heappush from sqlglot import Dialect from sqlglot import expressions as exp -from sqlglot.helper import ensure_list +from sqlglot.helper import ensure_collection @dataclass(frozen=True) @@ -116,7 +116,9 @@ class ChangeDistiller: source_node = self._source_index[kept_source_node_id] target_node = self._target_index[kept_target_node_id] if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node: - edit_script.extend(self._generate_move_edits(source_node, target_node, matching_set)) + edit_script.extend( + self._generate_move_edits(source_node, target_node, matching_set) + ) edit_script.append(Keep(source_node, target_node)) else: edit_script.append(Update(source_node, target_node)) @@ -158,13 +160,16 @@ class ChangeDistiller: max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids)) if max_leaves_num: common_leaves_num = sum( - 1 if s in source_leaf_ids and t in target_leaf_ids else 0 for s, t in leaves_matching_set + 1 if s in source_leaf_ids and t in target_leaf_ids else 0 + for s, t in leaves_matching_set ) leaf_similarity_score = common_leaves_num / max_leaves_num else: leaf_similarity_score = 0.0 - adjusted_t = self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4 + adjusted_t = ( + self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4 + ) if leaf_similarity_score >= 0.8 or ( leaf_similarity_score >= adjusted_t @@ -201,7 +206,10 @@ class ChangeDistiller: matching_set = set() while candidate_matchings: _, _, source_leaf, target_leaf = heappop(candidate_matchings) - if id(source_leaf) in self._unmatched_source_nodes and id(target_leaf) in self._unmatched_target_nodes: + if ( + id(source_leaf) in self._unmatched_source_nodes + and id(target_leaf) in self._unmatched_target_nodes + ): matching_set.add((id(source_leaf), id(target_leaf))) self._unmatched_source_nodes.remove(id(source_leaf)) self._unmatched_target_nodes.remove(id(target_leaf)) @@ -241,8 +249,7 @@ def _get_leaves(expression): has_child_exprs = False for a in expression.args.values(): - nodes = ensure_list(a) - for node in nodes: + for node in ensure_collection(a): if isinstance(node, exp.Expression): has_child_exprs = True yield from _get_leaves(node) @@ -268,7 +275,7 @@ def _expression_only_args(expression): args = [] if expression: for a in expression.args.values(): - args.extend(ensure_list(a)) + args.extend(ensure_collection(a)) return [a for a in args if isinstance(a, exp.Expression)] diff --git a/sqlglot/errors.py b/sqlglot/errors.py index 89aa935..2ef908f 100644 --- a/sqlglot/errors.py +++ b/sqlglot/errors.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import typing as t from enum import auto from sqlglot.helper import AutoName @@ -30,7 +33,11 @@ class OptimizeError(SqlglotError): pass -def concat_errors(errors, maximum): +class SchemaError(SqlglotError): + pass + + +def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str: msg = [str(e) for e in errors[:maximum]] remaining = len(errors) - maximum if remaining > 0: diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py index d265a2c..393347b 100644 --- a/sqlglot/executor/context.py +++ b/sqlglot/executor/context.py @@ -19,6 +19,7 @@ class Context: env (Optional[dict]): dictionary of functions within the execution context """ self.tables = tables + self._table = None self.range_readers = {name: table.range_reader for name, table in self.tables.items()} self.row_readers = {name: table.reader for name, table in tables.items()} self.env = {**(env or {}), "scope": self.row_readers} @@ -29,8 +30,27 @@ class Context: def eval_tuple(self, codes): return tuple(self.eval(code) for code in codes) + @property + def table(self): + if self._table is None: + self._table = list(self.tables.values())[0] + for other in self.tables.values(): + if self._table.columns != other.columns: + raise Exception(f"Columns are different.") + if len(self._table.rows) != len(other.rows): + raise Exception(f"Rows are different.") + return self._table + + @property + def columns(self): + return self.table.columns + def __iter__(self): - return self.table_iter(list(self.tables)[0]) + self.env["scope"] = self.row_readers + for i in range(len(self.table.rows)): + for table in self.tables.values(): + reader = table[i] + yield reader, self def table_iter(self, table): self.env["scope"] = self.row_readers @@ -38,8 +58,8 @@ class Context: for reader in self.tables[table]: yield reader, self - def sort(self, table, key): - table = self.tables[table] + def sort(self, key): + table = self.table def sort_key(row): table.reader.row = row @@ -47,20 +67,20 @@ class Context: table.rows.sort(key=sort_key) - def set_row(self, table, row): - self.row_readers[table].row = row + def set_row(self, row): + for table in self.tables.values(): + table.reader.row = row self.env["scope"] = self.row_readers - def set_index(self, table, index): - self.row_readers[table].row = self.tables[table].rows[index] + def set_index(self, index): + for table in self.tables.values(): + table[index] self.env["scope"] = self.row_readers - def set_range(self, table, start, end): - self.range_readers[table].range = range(start, end) + def set_range(self, start, end): + for name in self.tables: + self.range_readers[name].range = range(start, end) self.env["scope"] = self.range_readers - def __getitem__(self, table): - return self.env["scope"][table] - def __contains__(self, table): return table in self.tables diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index 9c49dd1..bbe6c81 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -2,6 +2,8 @@ import datetime import re import statistics +from sqlglot.helper import PYTHON_VERSION + class reverse_key: def __init__(self, obj): @@ -25,7 +27,7 @@ ENV = { "str": str, "desc": reverse_key, "SUM": sum, - "AVG": statistics.fmean if hasattr(statistics, "fmean") else statistics.mean, + "AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean, # type: ignore "COUNT": lambda acc: sum(1 for e in acc if e is not None), "MAX": max, "MIN": min, diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index fcb016b..7d1db32 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -1,15 +1,14 @@ import ast import collections import itertools +import math -from sqlglot import exp, planner +from sqlglot import exp, generator, planner, tokens from sqlglot.dialects.dialect import Dialect, inline_array_sql from sqlglot.executor.context import Context from sqlglot.executor.env import ENV from sqlglot.executor.table import Table -from sqlglot.generator import Generator from sqlglot.helper import csv_reader -from sqlglot.tokens import Tokenizer class PythonExecutor: @@ -26,7 +25,11 @@ class PythonExecutor: while queue: node = queue.pop() context = self.context( - {name: table for dep in node.dependencies for name, table in contexts[dep].tables.items()} + { + name: table + for dep in node.dependencies + for name, table in contexts[dep].tables.items() + } ) running.add(node) @@ -76,13 +79,10 @@ class PythonExecutor: return Table(expression.alias_or_name for expression in expressions) def scan(self, step, context): - if hasattr(step, "source"): - source = step.source + source = step.source - if isinstance(source, exp.Expression): - source = source.name or source.alias - else: - source = step.name + if isinstance(source, exp.Expression): + source = source.name or source.alias condition = self.generate(step.condition) projections = self.generate_tuple(step.projections) @@ -96,14 +96,12 @@ class PythonExecutor: if projections: sink = self.table(step.projections) - elif source in context: - sink = Table(context[source].columns) else: sink = None for reader, ctx in table_iter: if sink is None: - sink = Table(ctx[source].columns) + sink = Table(reader.columns) if condition and not ctx.eval(condition): continue @@ -135,98 +133,79 @@ class PythonExecutor: types.append(type(ast.literal_eval(v))) except (ValueError, SyntaxError): types.append(str) - context.set_row(alias, tuple(t(v) for t, v in zip(types, row))) - yield context[alias], context + context.set_row(tuple(t(v) for t, v in zip(types, row))) + yield context.table.reader, context def join(self, step, context): source = step.name - join_context = self.context({source: context.tables[source]}) - - def merge_context(ctx, table): - # create a new context where all existing tables are mapped to a new one - return self.context({name: table for name in ctx.tables}) + source_table = context.tables[source] + source_context = self.context({source: source_table}) + column_ranges = {source: range(0, len(source_table.columns))} for name, join in step.joins.items(): - join_context = self.context({**join_context.tables, name: context.tables[name]}) + table = context.tables[name] + start = max(r.stop for r in column_ranges.values()) + column_ranges[name] = range(start, len(table.columns) + start) + join_context = self.context({name: table}) if join.get("source_key"): - table = self.hash_join(join, source, name, join_context) + table = self.hash_join(join, source_context, join_context) else: - table = self.nested_loop_join(join, source, name, join_context) + table = self.nested_loop_join(join, source_context, join_context) - join_context = merge_context(join_context, table) - - # apply projections or conditions - context = self.scan(step, join_context) + source_context = self.context( + { + name: Table(table.columns, table.rows, column_range) + for name, column_range in column_ranges.items() + } + ) - # use the scan context since it returns a single table - # otherwise there are no projections so all other tables are still in scope - if step.projections: - return context + condition = self.generate(step.condition) + projections = self.generate_tuple(step.projections) - return merge_context(join_context, context.tables[source]) + if not condition or not projections: + return source_context - def nested_loop_join(self, _join, a, b, context): - table = Table(context.tables[a].columns + context.tables[b].columns) + sink = self.table(step.projections if projections else source_context.columns) - for reader_a, _ in context.table_iter(a): - for reader_b, _ in context.table_iter(b): - table.append(reader_a.row + reader_b.row) + for reader, ctx in join_context: + if condition and not ctx.eval(condition): + continue - return table + if projections: + sink.append(ctx.eval_tuple(projections)) + else: + sink.append(reader.row) - def hash_join(self, join, a, b, context): - a_key = self.generate_tuple(join["source_key"]) - b_key = self.generate_tuple(join["join_key"]) + if len(sink) >= step.limit: + break - results = collections.defaultdict(lambda: ([], [])) + return self.context({step.name: sink}) - for reader, ctx in context.table_iter(a): - results[ctx.eval_tuple(a_key)][0].append(reader.row) - for reader, ctx in context.table_iter(b): - results[ctx.eval_tuple(b_key)][1].append(reader.row) + def nested_loop_join(self, _join, source_context, join_context): + table = Table(source_context.columns + join_context.columns) - table = Table(context.tables[a].columns + context.tables[b].columns) - for a_group, b_group in results.values(): - for a_row, b_row in itertools.product(a_group, b_group): - table.append(a_row + b_row) + for reader_a, _ in source_context: + for reader_b, _ in join_context: + table.append(reader_a.row + reader_b.row) return table - def sort_merge_join(self, join, a, b, context): - a_key = self.generate_tuple(join["source_key"]) - b_key = self.generate_tuple(join["join_key"]) - - context.sort(a, a_key) - context.sort(b, b_key) - - a_i = 0 - b_i = 0 - a_n = len(context.tables[a]) - b_n = len(context.tables[b]) - - table = Table(context.tables[a].columns + context.tables[b].columns) - - def get_key(source, key, i): - context.set_index(source, i) - return context.eval_tuple(key) + def hash_join(self, join, source_context, join_context): + source_key = self.generate_tuple(join["source_key"]) + join_key = self.generate_tuple(join["join_key"]) - while a_i < a_n and b_i < b_n: - key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i)) - - a_group = [] - - while a_i < a_n and key == get_key(a, a_key, a_i): - a_group.append(context[a].row) - a_i += 1 + results = collections.defaultdict(lambda: ([], [])) - b_group = [] + for reader, ctx in source_context: + results[ctx.eval_tuple(source_key)][0].append(reader.row) + for reader, ctx in join_context: + results[ctx.eval_tuple(join_key)][1].append(reader.row) - while b_i < b_n and key == get_key(b, b_key, b_i): - b_group.append(context[b].row) - b_i += 1 + table = Table(source_context.columns + join_context.columns) + for a_group, b_group in results.values(): for a_row, b_row in itertools.product(a_group, b_group): table.append(a_row + b_row) @@ -238,16 +217,18 @@ class PythonExecutor: aggregations = self.generate_tuple(step.aggregations) operands = self.generate_tuple(step.operands) - context.sort(source, group_by) - - if step.operands: + if operands: source_table = context.tables[source] operand_table = Table(source_table.columns + self.table(step.operands).columns) for reader, ctx in context: operand_table.append(reader.row + ctx.eval_tuple(operands)) - context = self.context({source: operand_table}) + context = self.context( + {None: operand_table, **{table: operand_table for table in context.tables}} + ) + + context.sort(group_by) group = None start = 0 @@ -256,15 +237,15 @@ class PythonExecutor: table = self.table(step.group + step.aggregations) for i in range(length): - context.set_index(source, i) + context.set_index(i) key = context.eval_tuple(group_by) group = key if group is None else group end += 1 if i == length - 1: - context.set_range(source, start, end - 1) + context.set_range(start, end - 1) elif key != group: - context.set_range(source, start, end - 2) + context.set_range(start, end - 2) else: continue @@ -272,13 +253,32 @@ class PythonExecutor: group = key start = end - 2 - return self.scan(step, self.context({source: table})) + context = self.context({step.name: table, **{name: table for name in context.tables}}) + + if step.projections: + return self.scan(step, context) + return context def sort(self, step, context): - table = list(context.tables)[0] - key = self.generate_tuple(step.key) - context.sort(table, key) - return self.scan(step, context) + projections = self.generate_tuple(step.projections) + + sink = self.table(step.projections) + + for reader, ctx in context: + sink.append(ctx.eval_tuple(projections)) + + context = self.context( + { + None: sink, + **{table: sink for table in context.tables}, + } + ) + context.sort(self.generate_tuple(step.key)) + + if not math.isinf(step.limit): + context.table.rows = context.table.rows[0 : step.limit] + + return self.context({step.name: context.table}) def _cast_py(self, expression): @@ -293,7 +293,7 @@ def _cast_py(self, expression): def _column_py(self, expression): - table = self.sql(expression, "table") + table = self.sql(expression, "table") or None this = self.sql(expression, "this") return f"scope[{table}][{this}]" @@ -319,10 +319,10 @@ def _ordered_py(self, expression): class Python(Dialect): - class Tokenizer(Tokenizer): - ESCAPE = "\\" + class Tokenizer(tokens.Tokenizer): + ESCAPES = ["\\"] - class Generator(Generator): + class Generator(generator.Generator): TRANSFORMS = { exp.Alias: lambda self, e: self.sql(e.this), exp.Array: inline_array_sql, diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py index 80674cb..6796740 100644 --- a/sqlglot/executor/table.py +++ b/sqlglot/executor/table.py @@ -1,10 +1,12 @@ class Table: - def __init__(self, *columns, rows=None): - self.columns = tuple(columns if isinstance(columns[0], str) else columns[0]) + def __init__(self, columns, rows=None, column_range=None): + self.columns = tuple(columns) + self.column_range = column_range + self.reader = RowReader(self.columns, self.column_range) + self.rows = rows or [] if rows: assert len(rows[0]) == len(self.columns) - self.reader = RowReader(self.columns) self.range_reader = RangeReader(self) def append(self, row): @@ -29,15 +31,22 @@ class Table: return self.reader def __repr__(self): - widths = {column: len(column) for column in self.columns} - lines = [" ".join(column for column in self.columns)] + columns = tuple( + column + for i, column in enumerate(self.columns) + if not self.column_range or i in self.column_range + ) + widths = {column: len(column) for column in columns} + lines = [" ".join(column for column in columns)] for i, row in enumerate(self): if i > 10: break lines.append( - " ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in self.columns) + " ".join( + str(row[column]).rjust(widths[column])[0 : widths[column]] for column in columns + ) ) return "\n".join(lines) @@ -70,8 +79,10 @@ class RangeReader: class RowReader: - def __init__(self, columns): - self.columns = {column: i for i, column in enumerate(columns)} + def __init__(self, columns, column_range=None): + self.columns = { + column: i for i, column in enumerate(columns) if not column_range or i in column_range + } self.row = None def __getitem__(self, column): diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 1691d85..57a2c88 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import datetime import numbers import re +import typing as t from collections import deque from copy import deepcopy from enum import auto @@ -9,12 +12,15 @@ from sqlglot.errors import ParseError from sqlglot.helper import ( AutoName, camel_to_snake_case, - ensure_list, - list_get, + ensure_collection, + seq_get, split_num_words, subclasses, ) +if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import Dialect + class _Expression(type): def __new__(cls, clsname, bases, attrs): @@ -35,27 +41,30 @@ class Expression(metaclass=_Expression): or optional (False). """ - key = None + key = "Expression" arg_types = {"this": True} - __slots__ = ("args", "parent", "arg_key", "type") + __slots__ = ("args", "parent", "arg_key", "type", "comment") def __init__(self, **args): self.args = args self.parent = None self.arg_key = None self.type = None + self.comment = None for arg_key, value in self.args.items(): self._set_parent(arg_key, value) - def __eq__(self, other): + def __eq__(self, other) -> bool: return type(self) is type(other) and _norm_args(self) == _norm_args(other) - def __hash__(self): + def __hash__(self) -> int: return hash( ( self.key, - tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()), + tuple( + (k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items() + ), ) ) @@ -79,6 +88,19 @@ class Expression(metaclass=_Expression): return field.this return "" + def find_comment(self, key: str) -> str: + """ + Finds the comment that is attached to a specified child node. + + Args: + key: the key of the target child node (e.g. "this", "expression", etc). + + Returns: + The comment attached to the child node, or the empty string, if it doesn't exist. + """ + field = self.args.get(key) + return field.comment if isinstance(field, Expression) else "" + @property def is_string(self): return isinstance(self, Literal) and self.args["is_string"] @@ -114,7 +136,10 @@ class Expression(metaclass=_Expression): return self.alias or self.name def __deepcopy__(self, memo): - return self.__class__(**deepcopy(self.args)) + copy = self.__class__(**deepcopy(self.args)) + copy.comment = self.comment + copy.type = self.type + return copy def copy(self): new = deepcopy(self) @@ -249,9 +274,7 @@ class Expression(metaclass=_Expression): return for k, v in self.args.items(): - nodes = ensure_list(v) - - for node in nodes: + for node in ensure_collection(v): if isinstance(node, Expression): yield from node.dfs(self, k, prune) @@ -274,9 +297,7 @@ class Expression(metaclass=_Expression): if isinstance(item, Expression): for k, v in item.args.items(): - nodes = ensure_list(v) - - for node in nodes: + for node in ensure_collection(v): if isinstance(node, Expression): queue.append((node, item, k)) @@ -319,7 +340,7 @@ class Expression(metaclass=_Expression): def __repr__(self): return self.to_s() - def sql(self, dialect=None, **opts): + def sql(self, dialect: Dialect | str | None = None, **opts) -> str: """ Returns SQL string representation of this tree. @@ -335,7 +356,7 @@ class Expression(metaclass=_Expression): return Dialect.get_or_raise(dialect)().generate(self, **opts) - def to_s(self, hide_missing=True, level=0): + def to_s(self, hide_missing: bool = True, level: int = 0) -> str: indent = "" if not level else "\n" indent += "".join([" "] * level) left = f"({self.key.upper()} " @@ -343,11 +364,13 @@ class Expression(metaclass=_Expression): args = { k: ", ".join( v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v) - for v in ensure_list(vs) + for v in ensure_collection(vs) if v is not None ) for k, vs in self.args.items() } + args["comment"] = self.comment + args["type"] = self.type args = {k: v for k, v in args.items() if v or not hide_missing} right = ", ".join(f"{k}: {v}" for k, v in args.items()) @@ -578,17 +601,6 @@ class UDTF(DerivedTable, Unionable): pass -class Annotation(Expression): - arg_types = { - "this": True, - "expression": True, - } - - @property - def alias(self): - return self.expression.alias_or_name - - class Cache(Expression): arg_types = { "with": False, @@ -623,6 +635,38 @@ class Describe(Expression): pass +class Set(Expression): + arg_types = {"expressions": True} + + +class SetItem(Expression): + arg_types = { + "this": True, + "kind": False, + "collate": False, # MySQL SET NAMES statement + } + + +class Show(Expression): + arg_types = { + "this": True, + "target": False, + "offset": False, + "limit": False, + "like": False, + "where": False, + "db": False, + "full": False, + "mutex": False, + "query": False, + "channel": False, + "global": False, + "log": False, + "position": False, + "types": False, + } + + class UserDefinedFunction(Expression): arg_types = {"this": True, "expressions": False} @@ -864,18 +908,20 @@ class Literal(Condition): def __eq__(self, other): return ( - isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"] + isinstance(other, Literal) + and self.this == other.this + and self.args["is_string"] == other.args["is_string"] ) def __hash__(self): return hash((self.key, self.this, self.args["is_string"])) @classmethod - def number(cls, number): + def number(cls, number) -> Literal: return cls(this=str(number), is_string=False) @classmethod - def string(cls, string): + def string(cls, string) -> Literal: return cls(this=str(string), is_string=True) @@ -1087,7 +1133,7 @@ class Properties(Expression): } @classmethod - def from_dict(cls, properties_dict): + def from_dict(cls, properties_dict) -> Properties: expressions = [] for key, value in properties_dict.items(): property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty) @@ -1323,7 +1369,7 @@ class Select(Subqueryable): **QUERY_MODIFIERS, } - def from_(self, *expressions, append=True, dialect=None, copy=True, **opts): + def from_(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Set the FROM expression. @@ -1356,7 +1402,7 @@ class Select(Subqueryable): **opts, ) - def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Set the GROUP BY expression. @@ -1392,7 +1438,7 @@ class Select(Subqueryable): **opts, ) - def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Set the ORDER BY expression. @@ -1425,7 +1471,7 @@ class Select(Subqueryable): **opts, ) - def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Set the SORT BY expression. @@ -1458,7 +1504,7 @@ class Select(Subqueryable): **opts, ) - def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Set the CLUSTER BY expression. @@ -1491,7 +1537,7 @@ class Select(Subqueryable): **opts, ) - def limit(self, expression, dialect=None, copy=True, **opts): + def limit(self, expression, dialect=None, copy=True, **opts) -> Select: """ Set the LIMIT expression. @@ -1522,7 +1568,7 @@ class Select(Subqueryable): **opts, ) - def offset(self, expression, dialect=None, copy=True, **opts): + def offset(self, expression, dialect=None, copy=True, **opts) -> Select: """ Set the OFFSET expression. @@ -1553,7 +1599,7 @@ class Select(Subqueryable): **opts, ) - def select(self, *expressions, append=True, dialect=None, copy=True, **opts): + def select(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Append to or set the SELECT expressions. @@ -1583,7 +1629,7 @@ class Select(Subqueryable): **opts, ) - def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts): + def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Append to or set the LATERAL expressions. @@ -1626,7 +1672,7 @@ class Select(Subqueryable): dialect=None, copy=True, **opts, - ): + ) -> Select: """ Append to or set the JOIN expressions. @@ -1672,7 +1718,7 @@ class Select(Subqueryable): join.this.replace(join.this.subquery()) if join_type: - natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) + natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore if natural: join.set("natural", True) if side: @@ -1681,12 +1727,12 @@ class Select(Subqueryable): join.set("kind", kind.text) if on: - on = and_(*ensure_list(on), dialect=dialect, **opts) + on = and_(*ensure_collection(on), dialect=dialect, **opts) join.set("on", on) if using: join = _apply_list_builder( - *ensure_list(using), + *ensure_collection(using), instance=join, arg="using", append=append, @@ -1705,7 +1751,7 @@ class Select(Subqueryable): **opts, ) - def where(self, *expressions, append=True, dialect=None, copy=True, **opts): + def where(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Append to or set the WHERE expressions. @@ -1737,7 +1783,7 @@ class Select(Subqueryable): **opts, ) - def having(self, *expressions, append=True, dialect=None, copy=True, **opts): + def having(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Append to or set the HAVING expressions. @@ -1769,7 +1815,7 @@ class Select(Subqueryable): **opts, ) - def distinct(self, distinct=True, copy=True): + def distinct(self, distinct=True, copy=True) -> Select: """ Set the OFFSET expression. @@ -1788,7 +1834,7 @@ class Select(Subqueryable): instance.set("distinct", Distinct() if distinct else None) return instance - def ctas(self, table, properties=None, dialect=None, copy=True, **opts): + def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create: """ Convert this expression to a CREATE TABLE AS statement. @@ -1826,11 +1872,11 @@ class Select(Subqueryable): ) @property - def named_selects(self): + def named_selects(self) -> t.List[str]: return [e.alias_or_name for e in self.expressions if e.alias_or_name] @property - def selects(self): + def selects(self) -> t.List[Expression]: return self.expressions @@ -1910,12 +1956,16 @@ class Parameter(Expression): pass +class SessionParameter(Expression): + arg_types = {"this": True, "kind": False} + + class Placeholder(Expression): arg_types = {"this": False} class Null(Condition): - arg_types = {} + arg_types: t.Dict[str, t.Any] = {} class Boolean(Condition): @@ -1936,6 +1986,7 @@ class DataType(Expression): NVARCHAR = auto() TEXT = auto() BINARY = auto() + VARBINARY = auto() INT = auto() TINYINT = auto() SMALLINT = auto() @@ -1975,7 +2026,7 @@ class DataType(Expression): UNKNOWN = auto() # Sentinel value, useful for type annotation @classmethod - def build(cls, dtype, **kwargs): + def build(cls, dtype, **kwargs) -> DataType: return DataType( this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()], **kwargs, @@ -2077,6 +2128,18 @@ class EQ(Binary, Predicate): pass +class NullSafeEQ(Binary, Predicate): + pass + + +class NullSafeNEQ(Binary, Predicate): + pass + + +class Distance(Binary): + pass + + class Escape(Binary): pass @@ -2101,15 +2164,11 @@ class Is(Binary, Predicate): pass -class Like(Binary, Predicate): - pass - - -class SimilarTo(Binary, Predicate): - pass +class Kwarg(Binary): + """Kwarg in special functions like func(kwarg => y).""" -class Distance(Binary): +class Like(Binary, Predicate): pass @@ -2133,6 +2192,10 @@ class NEQ(Binary, Predicate): pass +class SimilarTo(Binary, Predicate): + pass + + class Sub(Binary): pass @@ -2189,7 +2252,13 @@ class Distinct(Expression): class In(Predicate): - arg_types = {"this": True, "expressions": False, "query": False, "unnest": False, "field": False} + arg_types = { + "this": True, + "expressions": False, + "query": False, + "unnest": False, + "field": False, + } class TimeUnit(Expression): @@ -2255,7 +2324,9 @@ class Func(Condition): @classmethod def sql_names(cls): if cls is Func: - raise NotImplementedError("SQL name is only supported by concrete function implementations") + raise NotImplementedError( + "SQL name is only supported by concrete function implementations" + ) if not hasattr(cls, "_sql_names"): cls._sql_names = [camel_to_snake_case(cls.__name__)] return cls._sql_names @@ -2408,8 +2479,8 @@ class DateDiff(Func, TimeUnit): arg_types = {"this": True, "expression": True, "unit": False} -class DateTrunc(Func, TimeUnit): - arg_types = {"this": True, "unit": True, "zone": False} +class DateTrunc(Func): + arg_types = {"this": True, "expression": True, "zone": False} class DatetimeAdd(Func, TimeUnit): @@ -2791,6 +2862,10 @@ class Year(Func): pass +class Use(Expression): + pass + + def _norm_args(expression): args = {} @@ -2822,7 +2897,7 @@ def maybe_parse( dialect=None, prefix=None, **opts, -): +) -> t.Optional[Expression]: """Gracefully handle a possible string or expression. Example: @@ -3073,7 +3148,7 @@ def except_(left, right, distinct=True, dialect=None, **opts): return Except(this=left, expression=right, distinct=distinct) -def select(*expressions, dialect=None, **opts): +def select(*expressions, dialect=None, **opts) -> Select: """ Initializes a syntax tree from one or multiple SELECT expressions. @@ -3095,7 +3170,7 @@ def select(*expressions, dialect=None, **opts): return Select().select(*expressions, dialect=dialect, **opts) -def from_(*expressions, dialect=None, **opts): +def from_(*expressions, dialect=None, **opts) -> Select: """ Initializes a syntax tree from a FROM expression. @@ -3117,7 +3192,7 @@ def from_(*expressions, dialect=None, **opts): return Select().from_(*expressions, dialect=dialect, **opts) -def update(table, properties, where=None, from_=None, dialect=None, **opts): +def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update: """ Creates an update statement. @@ -3139,7 +3214,10 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts): update = Update(this=maybe_parse(table, into=Table, dialect=dialect)) update.set( "expressions", - [EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) for k, v in properties.items()], + [ + EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) + for k, v in properties.items() + ], ) if from_: update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts)) @@ -3150,7 +3228,7 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts): return update -def delete(table, where=None, dialect=None, **opts): +def delete(table, where=None, dialect=None, **opts) -> Delete: """ Builds a delete statement. @@ -3174,7 +3252,7 @@ def delete(table, where=None, dialect=None, **opts): ) -def condition(expression, dialect=None, **opts): +def condition(expression, dialect=None, **opts) -> Condition: """ Initialize a logical condition expression. @@ -3199,7 +3277,7 @@ def condition(expression, dialect=None, **opts): Returns: Condition: the expression """ - return maybe_parse( + return maybe_parse( # type: ignore expression, into=Condition, dialect=dialect, @@ -3207,7 +3285,7 @@ def condition(expression, dialect=None, **opts): ) -def and_(*expressions, dialect=None, **opts): +def and_(*expressions, dialect=None, **opts) -> And: """ Combine multiple conditions with an AND logical operator. @@ -3227,7 +3305,7 @@ def and_(*expressions, dialect=None, **opts): return _combine(expressions, And, dialect, **opts) -def or_(*expressions, dialect=None, **opts): +def or_(*expressions, dialect=None, **opts) -> Or: """ Combine multiple conditions with an OR logical operator. @@ -3247,7 +3325,7 @@ def or_(*expressions, dialect=None, **opts): return _combine(expressions, Or, dialect, **opts) -def not_(expression, dialect=None, **opts): +def not_(expression, dialect=None, **opts) -> Not: """ Wrap a condition with a NOT operator. @@ -3272,14 +3350,14 @@ def not_(expression, dialect=None, **opts): return Not(this=_wrap_operator(this)) -def paren(expression): +def paren(expression) -> Paren: return Paren(this=expression) SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$") -def to_identifier(alias, quoted=None): +def to_identifier(alias, quoted=None) -> t.Optional[Identifier]: if alias is None: return None if isinstance(alias, Identifier): @@ -3293,16 +3371,16 @@ def to_identifier(alias, quoted=None): return identifier -def to_table(sql_path: str, **kwargs) -> Table: +def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]: """ Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. - If a table is passed in then that table is returned. Args: - sql_path(str|Table): `[catalog].[schema].[table]` string + sql_path: a `[catalog].[schema].[table]` string. + Returns: - Table: A table expression + A table expression. """ if sql_path is None or isinstance(sql_path, Table): return sql_path @@ -3393,7 +3471,7 @@ def subquery(expression, alias=None, dialect=None, **opts): return Select().from_(expression, dialect=dialect, **opts) -def column(col, table=None, quoted=None): +def column(col, table=None, quoted=None) -> Column: """ Build a Column. Args: @@ -3408,7 +3486,7 @@ def column(col, table=None, quoted=None): ) -def table_(table, db=None, catalog=None, quoted=None, alias=None): +def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table: """Build a Table. Args: @@ -3427,7 +3505,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None): ) -def values(values, alias=None): +def values(values, alias=None) -> Values: """Build VALUES statement. Example: @@ -3449,7 +3527,7 @@ def values(values, alias=None): ) -def convert(value): +def convert(value) -> Expression: """Convert a python value into an expression object. Raises an error if a conversion is not possible. @@ -3500,15 +3578,14 @@ def replace_children(expression, fun): for cn in child_nodes: if isinstance(cn, Expression): - cns = ensure_list(fun(cn)) - for child_node in cns: + for child_node in ensure_collection(fun(cn)): new_child_nodes.append(child_node) child_node.parent = expression child_node.arg_key = k else: new_child_nodes.append(cn) - expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0) + expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0) def column_table_names(expression): @@ -3529,7 +3606,7 @@ def column_table_names(expression): return list(dict.fromkeys(column.table for column in expression.find_all(Column))) -def table_name(table): +def table_name(table) -> str: """Get the full name of a table as a string. Args: @@ -3546,6 +3623,9 @@ def table_name(table): table = maybe_parse(table, into=Table) + if not table: + raise ValueError(f"Cannot parse {table}") + return ".".join( part for part in ( diff --git a/sqlglot/generator.py b/sqlglot/generator.py index ca14425..11d9073 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1,4 +1,8 @@ +from __future__ import annotations + import logging +import re +import typing as t from sqlglot import exp from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors @@ -8,6 +12,8 @@ from sqlglot.tokens import TokenType logger = logging.getLogger("sqlglot") +NEWLINE_RE = re.compile("\r\n?|\n") + class Generator: """ @@ -47,8 +53,7 @@ class Generator: The default is on the smaller end because the length only represents a segment and not the true line length. Default: 80 - annotations: Whether or not to show annotations in the SQL when `pretty` is True. - Annotations can only be shown in pretty mode otherwise they may clobber resulting sql. + comments: Whether or not to preserve comments in the ouput SQL code. Default: True """ @@ -65,14 +70,16 @@ class Generator: exp.VolatilityProperty: lambda self, e: self.sql(e.name), } - # whether 'CREATE ... TRANSIENT ... TABLE' is allowed - # can override in dialects + # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed CREATE_TRANSIENT = False - # whether or not null ordering is supported in order by + + # Whether or not null ordering is supported in order by NULL_ORDERING_SUPPORTED = True - # always do union distinct or union all + + # Always do union distinct or union all EXPLICIT_UNION = False - # wrap derived values in parens, usually standard but spark doesn't support it + + # Wrap derived values in parens, usually standard but spark doesn't support it WRAP_DERIVED_VALUES = True TYPE_MAPPING = { @@ -80,7 +87,7 @@ class Generator: exp.DataType.Type.NVARCHAR: "VARCHAR", } - TOKEN_MAPPING = {} + TOKEN_MAPPING: t.Dict[TokenType, str] = {} STRUCT_DELIMITER = ("<", ">") @@ -96,6 +103,8 @@ class Generator: exp.TableFormatProperty, } + WITH_SEPARATED_COMMENTS = (exp.Select,) + __slots__ = ( "time_mapping", "time_trie", @@ -122,7 +131,7 @@ class Generator: "_escaped_quote_end", "_leading_comma", "_max_text_width", - "_annotations", + "_comments", ) def __init__( @@ -148,7 +157,7 @@ class Generator: max_unsupported=3, leading_comma=False, max_text_width=80, - annotations=True, + comments=True, ): import sqlglot @@ -177,7 +186,7 @@ class Generator: self._escaped_quote_end = self.escape + self.quote_end self._leading_comma = leading_comma self._max_text_width = max_text_width - self._annotations = annotations + self._comments = comments def generate(self, expression): """ @@ -204,7 +213,6 @@ class Generator: return sql def unsupported(self, message): - if self.unsupported_level == ErrorLevel.IMMEDIATE: raise UnsupportedError(message) self.unsupported_messages.append(message) @@ -215,9 +223,31 @@ class Generator: def seg(self, sql, sep=" "): return f"{self.sep(sep)}{sql}" + def maybe_comment(self, sql, expression, single_line=False): + comment = expression.comment if self._comments else None + + if not comment: + return sql + + comment = " " + comment if comment[0].strip() else comment + comment = comment + " " if comment[-1].strip() else comment + + if isinstance(expression, self.WITH_SEPARATED_COMMENTS): + return f"/*{comment}*/{self.sep()}{sql}" + + if not self.pretty: + return f"{sql} /*{comment}*/" + + if not NEWLINE_RE.search(comment): + return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/" + + return f"/*{comment}*/\n{sql}" + def wrap(self, expression): this_sql = self.indent( - self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"), + self.sql(expression) + if isinstance(expression, (exp.Select, exp.Union)) + else self.sql(expression, "this"), level=1, pad=0, ) @@ -251,7 +281,7 @@ class Generator: for i, line in enumerate(lines) ) - def sql(self, expression, key=None): + def sql(self, expression, key=None, comment=True): if not expression: return "" @@ -264,29 +294,24 @@ class Generator: transform = self.TRANSFORMS.get(expression.__class__) if callable(transform): - return transform(self, expression) - if transform: - return transform - - if not isinstance(expression, exp.Expression): + sql = transform(self, expression) + elif transform: + sql = transform + elif isinstance(expression, exp.Expression): + exp_handler_name = f"{expression.key}_sql" + + if hasattr(self, exp_handler_name): + sql = getattr(self, exp_handler_name)(expression) + elif isinstance(expression, exp.Func): + sql = self.function_fallback_sql(expression) + elif isinstance(expression, exp.Property): + sql = self.property_sql(expression) + else: + raise ValueError(f"Unsupported expression type {expression.__class__.__name__}") + else: raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}") - exp_handler_name = f"{expression.key}_sql" - if hasattr(self, exp_handler_name): - return getattr(self, exp_handler_name)(expression) - - if isinstance(expression, exp.Func): - return self.function_fallback_sql(expression) - - if isinstance(expression, exp.Property): - return self.property_sql(expression) - - raise ValueError(f"Unsupported expression type {expression.__class__.__name__}") - - def annotation_sql(self, expression): - if self._annotations and self.pretty: - return f"{self.sql(expression, 'expression')} # {expression.name}" - return self.sql(expression, "expression") + return self.maybe_comment(sql, expression) if self._comments and comment else sql def uncache_sql(self, expression): table = self.sql(expression, "this") @@ -371,7 +396,9 @@ class Generator: expression_sql = self.sql(expression, "expression") expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else "" temporary = " TEMPORARY" if expression.args.get("temporary") else "" - transient = " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" + transient = ( + " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" + ) replace = " OR REPLACE" if expression.args.get("replace") else "" exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" unique = " UNIQUE" if expression.args.get("unique") else "" @@ -434,7 +461,9 @@ class Generator: def delete_sql(self, expression): this = self.sql(expression, "this") using_sql = ( - f" USING {self.expressions(expression, 'using', sep=', USING ')}" if expression.args.get("using") else "" + f" USING {self.expressions(expression, 'using', sep=', USING ')}" + if expression.args.get("using") + else "" ) where_sql = self.sql(expression, "where") sql = f"DELETE FROM {this}{using_sql}{where_sql}" @@ -481,15 +510,18 @@ class Generator: return f"{this} ON {table} {columns}" def identifier_sql(self, expression): - value = expression.name - value = value.lower() if self.normalize else value + text = expression.name + text = text.lower() if self.normalize else text if expression.args.get("quoted") or self.identify: - return f"{self.identifier_start}{value}{self.identifier_end}" - return value + text = f"{self.identifier_start}{text}{self.identifier_end}" + return text def partition_sql(self, expression): keys = csv( - *[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")] + *[ + f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name + for prop in expression.this + ] ) return f"PARTITION({keys})" @@ -504,9 +536,9 @@ class Generator: elif p_class in self.ROOT_PROPERTIES: root_properties.append(p) - return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties( - exp.Properties(expressions=with_properties) - ) + return self.root_properties( + exp.Properties(expressions=root_properties) + ) + self.with_properties(exp.Properties(expressions=with_properties)) def root_properties(self, properties): if properties.expressions: @@ -551,7 +583,9 @@ class Generator: this = f"{this}{self.sql(expression, 'this')}" exists = " IF EXISTS " if expression.args.get("exists") else " " - partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else "" + partition_sql = ( + self.sql(expression, "partition") if expression.args.get("partition") else "" + ) expression_sql = self.sql(expression, "expression") sep = self.sep() if partition_sql else "" sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}" @@ -669,7 +703,9 @@ class Generator: def group_sql(self, expression): group_by = self.op_expressions("GROUP BY", expression) grouping_sets = self.expressions(expression, key="grouping_sets", indent=False) - grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else "" + grouping_sets = ( + f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else "" + ) cube = self.expressions(expression, key="cube", indent=False) cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else "" rollup = self.expressions(expression, key="rollup", indent=False) @@ -711,10 +747,10 @@ class Generator: this_sql = self.sql(expression, "this") return f"{expression_sql}{op_sql} {this_sql}{on_sql}" - def lambda_sql(self, expression): + def lambda_sql(self, expression, arrow_sep="->"): args = self.expressions(expression, flat=True) args = f"({args})" if len(args.split(",")) > 1 else args - return self.no_identify(lambda: f"{args} -> {self.sql(expression, 'this')}") + return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}") def lateral_sql(self, expression): this = self.sql(expression, "this") @@ -748,7 +784,7 @@ class Generator: if self._replace_backslash: text = text.replace("\\", "\\\\") text = text.replace(self.quote_end, self._escaped_quote_end) - return f"{self.quote_start}{text}{self.quote_end}" + text = f"{self.quote_start}{text}{self.quote_end}" return text def loaddata_sql(self, expression): @@ -796,13 +832,21 @@ class Generator: sort_order = " DESC" if desc else "" nulls_sort_change = "" - if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last): + if nulls_first and ( + (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last + ): nulls_sort_change = " NULLS FIRST" - elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last: + elif ( + nulls_last + and ((asc and nulls_are_small) or (desc and nulls_are_large)) + and not nulls_are_last + ): nulls_sort_change = " NULLS LAST" if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED: - self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect") + self.unsupported( + "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect" + ) nulls_sort_change = "" return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}" @@ -835,7 +879,7 @@ class Generator: sql = self.query_modifiers( expression, f"SELECT{hint}{distinct}{expressions}", - self.sql(expression, "from"), + self.sql(expression, "from", comment=False), ) return self.prepend_ctes(expression, sql) @@ -858,6 +902,13 @@ class Generator: def parameter_sql(self, expression): return f"@{self.sql(expression, 'this')}" + def sessionparameter_sql(self, expression): + this = self.sql(expression, "this") + kind = expression.text("kind") + if kind: + kind = f"{kind}." + return f"@@{kind}{this}" + def placeholder_sql(self, expression): return f":{expression.name}" if expression.name else "?" @@ -931,7 +982,10 @@ class Generator: def window_spec_sql(self, expression): kind = self.sql(expression, "kind") start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ") - end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW" + end = ( + csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") + or "CURRENT ROW" + ) return f"{kind} BETWEEN {start} AND {end}" def withingroup_sql(self, expression): @@ -1020,7 +1074,9 @@ class Generator: return f"UNIQUE ({columns})" def if_sql(self, expression): - return self.case_sql(exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))) + return self.case_sql( + exp.Case(ifs=[expression.copy()], default=expression.args.get("false")) + ) def in_sql(self, expression): query = expression.args.get("query") @@ -1196,6 +1252,12 @@ class Generator: def neq_sql(self, expression): return self.binary(expression, "<>") + def nullsafeeq_sql(self, expression): + return self.binary(expression, "IS NOT DISTINCT FROM") + + def nullsafeneq_sql(self, expression): + return self.binary(expression, "IS DISTINCT FROM") + def or_sql(self, expression): return self.connector_sql(expression, "OR") @@ -1205,6 +1267,9 @@ class Generator: def trycast_sql(self, expression): return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" + def use_sql(self, expression): + return f"USE {self.sql(expression, 'this')}" + def binary(self, expression, op): return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}" @@ -1240,17 +1305,27 @@ class Generator: if flat: return sep.join(self.sql(e) for e in expressions) - sql = (self.sql(e) for e in expressions) - # the only time leading_comma changes the output is if pretty print is enabled - if self._leading_comma and self.pretty: - pad = " " * self.pad - expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql)) - else: - expressions = self.sep(sep).join(sql) + num_sqls = len(expressions) + + # These are calculated once in case we have the leading_comma / pretty option set, correspondingly + pad = " " * self.pad + stripped_sep = sep.strip() - if indent: - return self.indent(expressions, skip_first=False) - return expressions + result_sqls = [] + for i, e in enumerate(expressions): + sql = self.sql(e, comment=False) + comment = self.maybe_comment("", e, single_line=True) + + if self.pretty: + if self._leading_comma: + result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}") + else: + result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}") + else: + result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}") + + result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) + return self.indent(result_sqls, skip_first=False) if indent else result_sqls def op_expressions(self, op, expression, flat=False): expressions_sql = self.expressions(expression, flat=flat) @@ -1264,7 +1339,9 @@ class Generator: def set_operation(self, expression, op): this = self.sql(expression, "this") op = self.seg(op) - return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}") + return self.query_modifiers( + expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}" + ) def token_sql(self, token_type): return self.TOKEN_MAPPING.get(token_type, token_type.name) @@ -1283,3 +1360,6 @@ class Generator: this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True) return f"{this}({expressions})" + + def kwarg_sql(self, expression): + return self.binary(expression, "=>") diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 42965d1..379c2e7 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -1,48 +1,125 @@ +from __future__ import annotations + import inspect import logging import re import sys import typing as t +from collections.abc import Collection from contextlib import contextmanager from copy import copy from enum import Enum +if t.TYPE_CHECKING: + from sqlglot.expressions import Expression, Table + + T = t.TypeVar("T") + E = t.TypeVar("E", bound=Expression) + CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])") +PYTHON_VERSION = sys.version_info[:2] logger = logging.getLogger("sqlglot") class AutoName(Enum): - def _generate_next_value_(name, _start, _count, _last_values): + """This is used for creating enum classes where `auto()` is the string form of the corresponding value's name.""" + + def _generate_next_value_(name, _start, _count, _last_values): # type: ignore return name -def list_get(arr, index): +def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]: + """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds.""" try: - return arr[index] + return seq[index] except IndexError: return None +@t.overload +def ensure_list(value: t.Collection[T]) -> t.List[T]: + ... + + +@t.overload +def ensure_list(value: T) -> t.List[T]: + ... + + def ensure_list(value): + """ + Ensures that a value is a list, otherwise casts or wraps it into one. + + Args: + value: the value of interest. + + Returns: + The value cast as a list if it's a list or a tuple, or else the value wrapped in a list. + """ if value is None: return [] - return value if isinstance(value, (list, tuple, set)) else [value] + elif isinstance(value, (list, tuple)): + return list(value) + + return [value] + + +@t.overload +def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: + ... -def csv(*args, sep=", "): +@t.overload +def ensure_collection(value: T) -> t.Collection[T]: + ... + + +def ensure_collection(value): + """ + Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list. + + Args: + value: the value of interest. + + Returns: + The value if it's a collection, or else the value wrapped in a list. + """ + if value is None: + return [] + return ( + value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value] + ) + + +def csv(*args, sep: str = ", ") -> str: + """ + Formats any number of string arguments as CSV. + + Args: + args: the string arguments to format. + sep: the argument separator. + + Returns: + The arguments formatted as a CSV string. + """ return sep.join(arg for arg in args if arg) -def subclasses(module_name, classes, exclude=()): +def subclasses( + module_name: str, + classes: t.Type | t.Tuple[t.Type, ...], + exclude: t.Type | t.Tuple[t.Type, ...] = (), +) -> t.List[t.Type]: """ - Returns a list of all subclasses for a specified class set, posibly excluding some of them. + Returns all subclasses for a collection of classes, possibly excluding some of them. Args: - module_name (str): The name of the module to search for subclasses in. - classes (type|tuple[type]): Class(es) we want to find the subclasses of. - exclude (type|tuple[type]): Class(es) we want to exclude from the returned list. + module_name: the name of the module to search for subclasses in. + classes: class(es) we want to find the subclasses of. + exclude: class(es) we want to exclude from the returned list. + Returns: - A list of all the target subclasses. + The target subclasses. """ return [ obj @@ -53,7 +130,18 @@ def subclasses(module_name, classes, exclude=()): ] -def apply_index_offset(expressions, offset): +def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]: + """ + Applies an offset to a given integer literal expression. + + Args: + expressions: the expression the offset will be applied to, wrapped in a list. + offset: the offset that will be applied. + + Returns: + The original expression with the offset applied to it, wrapped in a list. If the provided + `expressions` argument contains more than one expressions, it's returned unaffected. + """ if not offset or len(expressions) != 1: return expressions @@ -64,14 +152,28 @@ def apply_index_offset(expressions, offset): logger.warning("Applying array index offset (%s)", offset) expression.args["this"] = str(int(expression.args["this"]) + offset) return [expression] + return expressions -def camel_to_snake_case(name): +def camel_to_snake_case(name: str) -> str: + """Converts `name` from camelCase to snake_case and returns the result.""" return CAMEL_CASE_PATTERN.sub("_", name).upper() -def while_changing(expression, func): +def while_changing( + expression: t.Optional[Expression], func: t.Callable[[t.Optional[Expression]], E] +) -> E: + """ + Applies a transformation to a given expression until a fix point is reached. + + Args: + expression: the expression to be transformed. + func: the transformation to be applied. + + Returns: + The transformed expression. + """ while True: start = hash(expression) expression = func(expression) @@ -80,10 +182,19 @@ def while_changing(expression, func): return expression -def tsort(dag): +def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]: + """ + Sorts a given directed acyclic graph in topological order. + + Args: + dag: the graph to be sorted. + + Returns: + A list that contains all of the graph's nodes in topological order. + """ result = [] - def visit(node, visited): + def visit(node: T, visited: t.Set[T]) -> None: if node in result: return if node in visited: @@ -103,10 +214,8 @@ def tsort(dag): return result -def open_file(file_name): - """ - Open a file that may be compressed as gzip and return in newline mode. - """ +def open_file(file_name: str) -> t.TextIO: + """Open a file that may be compressed as gzip and return it in universal newline mode.""" with open(file_name, "rb") as f: gzipped = f.read(2) == b"\x1f\x8b" @@ -119,14 +228,14 @@ def open_file(file_name): @contextmanager -def csv_reader(table): +def csv_reader(table: Table) -> t.Any: """ - Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]) + Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`. Args: - table (exp.Table): A table expression with an anonymous function READ_CSV in it + table: a `Table` expression with an anonymous function `READ_CSV` in it. - Returns: + Yields: A python csv reader. """ file, *args = table.this.expressions @@ -147,13 +256,16 @@ def csv_reader(table): file.close() -def find_new_name(taken, base): +def find_new_name(taken: t.Sequence[str], base: str) -> str: """ Searches for a new name. Args: - taken (Sequence[str]): set of taken names - base (str): base name to alter + taken: a collection of taken names. + base: base name to alter. + + Returns: + The new, available name. """ if base not in taken: return base @@ -163,22 +275,26 @@ def find_new_name(taken, base): while new in taken: i += 1 new = f"{base}_{i}" + return new -def object_to_dict(obj, **kwargs): +def object_to_dict(obj: t.Any, **kwargs) -> t.Dict: + """Returns a dictionary created from an object's attributes.""" return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs} -def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> t.List[t.Optional[str]]: +def split_num_words( + value: str, sep: str, min_num_words: int, fill_from_start: bool = True +) -> t.List[t.Optional[str]]: """ - Perform a split on a value and return N words as a result with None used for words that don't exist. + Perform a split on a value and return N words as a result with `None` used for words that don't exist. Args: - value: The value to be split - sep: The value to use to split on - min_num_words: The minimum number of words that are going to be in the result - fill_from_start: Indicates that if None values should be inserted at the start or end of the list + value: the value to be split. + sep: the value to use to split on. + min_num_words: the minimum number of words that are going to be in the result. + fill_from_start: indicates that if `None` values should be inserted at the start or end of the list. Examples: >>> split_num_words("db.table", ".", 3) @@ -187,6 +303,9 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b ['db', 'table', None] >>> split_num_words("db.table", ".", 1) ['db', 'table'] + + Returns: + The list of words returned by `split`, possibly augmented by a number of `None` values. """ words = value.split(sep) if fill_from_start: @@ -196,7 +315,7 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b def is_iterable(value: t.Any) -> bool: """ - Checks if the value is an iterable but does not include strings and bytes + Checks if the value is an iterable, excluding the types `str` and `bytes`. Examples: >>> is_iterable([1,2]) @@ -205,28 +324,30 @@ def is_iterable(value: t.Any) -> bool: False Args: - value: The value to check if it is an interable + value: the value to check if it is an iterable. - Returns: Bool indicating if it is an iterable + Returns: + A `bool` value indicating if it is an iterable. """ return hasattr(value, "__iter__") and not isinstance(value, (str, bytes)) -def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generator[t.Any, None, None]: +def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any, None, None]: """ - Flattens a list that can contain both iterables and non-iterable elements + Flattens an iterable that can contain both iterable and non-iterable elements. Objects of + type `str` and `bytes` are not regarded as iterables. Examples: - >>> list(flatten([[1, 2], 3])) - [1, 2, 3] + >>> list(flatten([[1, 2], 3, {4}, (5, "bla")])) + [1, 2, 3, 4, 5, 'bla'] >>> list(flatten([1, 2, 3])) [1, 2, 3] Args: - values: The value to be flattened + values: the value to be flattened. - Returns: - Yields non-iterable elements (not including str or byte as iterable) + Yields: + Non-iterable elements in `values`. """ for value in values: if is_iterable(value): diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 30055bc..96331e2 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,5 +1,5 @@ from sqlglot import exp -from sqlglot.helper import ensure_list, subclasses +from sqlglot.helper import ensure_collection, ensure_list, subclasses from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import ensure_schema @@ -48,35 +48,65 @@ class TypeAnnotator: exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL), exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN), - exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), + exp.ApproxDistinct: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.BIGINT + ), exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), - exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.CurrentDatetime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATETIME + ), + exp.CurrentTime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), + exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), - exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), + exp.DatetimeAdd: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATETIME + ), + exp.DatetimeSub: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATETIME + ), exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimestampAdd: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), + exp.TimestampSub: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.DateStrToDate: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATE + ), + exp.DateToDateStr: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"), + exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"), + exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), + exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"), + exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.GroupConcat: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), + exp.ArrayConcat: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), @@ -88,32 +118,52 @@ class TypeAnnotator: exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.ApproxQuantile: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DOUBLE + ), + exp.RegexpLike: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.BOOLEAN + ), exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.StrToTime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), + exp.TimeStrToDate: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATE + ), + exp.TimeStrToTime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.UnixToTime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), + exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.VariancePop: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DOUBLE + ), exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), } @@ -124,7 +174,11 @@ class TypeAnnotator: exp.DataType.Type.TEXT: set(), exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT}, exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT}, - exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT}, + exp.DataType.Type.NCHAR: { + exp.DataType.Type.VARCHAR, + exp.DataType.Type.NVARCHAR, + exp.DataType.Type.TEXT, + }, exp.DataType.Type.CHAR: { exp.DataType.Type.NCHAR, exp.DataType.Type.VARCHAR, @@ -135,7 +189,11 @@ class TypeAnnotator: exp.DataType.Type.DOUBLE: set(), exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE}, exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}, - exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}, + exp.DataType.Type.BIGINT: { + exp.DataType.Type.DECIMAL, + exp.DataType.Type.FLOAT, + exp.DataType.Type.DOUBLE, + }, exp.DataType.Type.INT: { exp.DataType.Type.BIGINT, exp.DataType.Type.DECIMAL, @@ -160,7 +218,10 @@ class TypeAnnotator: # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ exp.DataType.Type.TIMESTAMPLTZ: set(), exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ}, - exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ}, + exp.DataType.Type.TIMESTAMP: { + exp.DataType.Type.TIMESTAMPTZ, + exp.DataType.Type.TIMESTAMPLTZ, + }, exp.DataType.Type.DATETIME: { exp.DataType.Type.TIMESTAMP, exp.DataType.Type.TIMESTAMPTZ, @@ -219,7 +280,7 @@ class TypeAnnotator: def _annotate_args(self, expression): for value in expression.args.values(): - for v in ensure_list(value): + for v in ensure_collection(value): self._maybe_annotate(v) return expression @@ -243,7 +304,9 @@ class TypeAnnotator: if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: expression.type = exp.DataType.Type.NULL elif exp.DataType.Type.NULL in (left_type, right_type): - expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")) + expression.type = exp.DataType.build( + "NULLABLE", expressions=exp.DataType.build("BOOLEAN") + ) else: expression.type = exp.DataType.Type.BOOLEAN elif isinstance(expression, (exp.Condition, exp.Predicate)): @@ -276,3 +339,17 @@ class TypeAnnotator: def _annotate_with_type(self, expression, target_type): expression.type = target_type return self._annotate_args(expression) + + def _annotate_by_args(self, expression, *args): + self._annotate_args(expression) + expressions = [] + for arg in args: + arg_expr = expression.args.get(arg) + expressions.extend(expr for expr in ensure_list(arg_expr) if expr) + + last_datatype = None + for expr in expressions: + last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type) + + expression.type = last_datatype or exp.DataType.Type.UNKNOWN + return expression diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index 0854336..29621af 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -60,7 +60,9 @@ def _join_is_used(scope, join, alias): on_clause_columns = set(id(column) for column in on.find_all(exp.Column)) else: on_clause_columns = set() - return any(column for column in scope.source_columns(alias) if id(column) not in on_clause_columns) + return any( + column for column in scope.source_columns(alias) if id(column) not in on_clause_columns + ) def _is_joined_on_all_unique_outputs(scope, join): diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index e30c263..8704e90 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -45,7 +45,13 @@ def eliminate_subqueries(expression): # All table names are taken for scope in root.traverse(): - taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)}) + taken.update( + { + source.name: source + for _, source in scope.sources.items() + if isinstance(source, exp.Table) + } + ) # Map of Expression->alias # Existing CTES in the root expression. We'll use this for deduplication. @@ -70,7 +76,9 @@ def eliminate_subqueries(expression): new_ctes.append(cte_scope.expression.parent) # Now append the rest - for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes): + for scope in itertools.chain( + root.union_scopes, root.subquery_scopes, root.derived_table_scopes + ): for child_scope in scope.traverse(): new_cte = _eliminate(child_scope, existing_ctes, taken) if new_cte: diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 70e4629..9ae4966 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -122,7 +122,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): unmergable_window_columns = [ column for column in outer_scope.columns - if column.find_ancestor(exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc) + if column.find_ancestor( + exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc + ) ] window_expressions_in_unmergable = [ column @@ -147,7 +149,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): and not ( isinstance(from_or_join, exp.From) and inner_select.args.get("where") - and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])) + and any( + j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []) + ) ) and not _is_a_window_expression_in_unmergable_operation() ) @@ -203,7 +207,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias): if table.alias_or_name == node_to_replace.alias_or_name: table.set("this", exp.to_identifier(new_subquery.alias_or_name)) outer_scope.remove_source(alias) - outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]) + outer_scope.add_source( + new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name] + ) def _merge_joins(outer_scope, inner_scope, from_or_join): @@ -296,7 +302,9 @@ def _merge_order(outer_scope, inner_scope): inner_scope (sqlglot.optimizer.scope.Scope) """ if ( - any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]) + any( + outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"] + ) or len(outer_scope.selected_sources) != 1 or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions) ): diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index ab30d7a..db538ef 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -50,7 +50,9 @@ def normalization_distance(expression, dnf=False): Returns: int: difference """ - return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1) + return sum(_predicate_lengths(expression, dnf)) - ( + len(list(expression.find_all(exp.Connector))) + 1 + ) def _predicate_lengths(expression, dnf): diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 0c74e36..40e4ab1 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -68,4 +68,8 @@ def normalize(expression): def other_table_names(join, exclude): - return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude] + return [ + name + for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) + if name != exclude + ] diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 5ad8f46..b2ed062 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -58,6 +58,8 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar # Find any additional rule parameters, beyond `expression` rule_params = rule.__code__.co_varnames - rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs} + rule_kwargs = { + param: possible_kwargs[param] for param in rule_params if param in possible_kwargs + } expression = rule(expression, **rule_kwargs) return expression diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index 583d059..6364f65 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -45,7 +45,11 @@ def pushdown(condition, sources, scope_ref_count): condition = condition.replace(simplify(condition)) cnf_like = normalized(condition) or not normalized(condition, dnf=True) - predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition]) + predicates = list( + condition.flatten() + if isinstance(condition, exp.And if cnf_like else exp.Or) + else [condition] + ) if cnf_like: pushdown_cnf(predicates, sources, scope_ref_count) @@ -104,11 +108,17 @@ def pushdown_dnf(predicates, scope, scope_ref_count): for column in predicate.find_all(exp.Column): if column.table == table: condition = column.find_ancestor(exp.Condition) - predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition + predicate_condition = ( + exp.and_(predicate_condition, condition) + if predicate_condition + else condition + ) if predicate_condition: conditions[table] = ( - exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition + exp.or_(conditions[table], predicate_condition) + if table in conditions + else predicate_condition ) for name, node in nodes.items(): @@ -146,10 +156,16 @@ def nodes_for_predicate(predicate, sources, scope_ref_count): nodes[table] = node elif isinstance(node, exp.Select) and len(tables) == 1: # We can't push down window expressions - has_window_expression = any(select for select in node.selects if select.find(exp.Window)) + has_window_expression = any( + select for select in node.selects if select.find(exp.Window) + ) # we can't push down predicates to select statements if they are referenced in # multiple places. - if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression: + if ( + not node.args.get("group") + and scope_ref_count[id(source)] < 2 + and not has_window_expression + ): nodes[table] = node return nodes @@ -165,7 +181,7 @@ def replace_aliases(source, predicate): def _replace_alias(column): if isinstance(column, exp.Column) and column.name in aliases: - return aliases[column.name] + return aliases[column.name].copy() return column return predicate.transform(_replace_alias) diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 5820851..abd9492 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -98,7 +98,9 @@ def _remove_unused_selections(scope, parent_selections): def _remove_indexed_selections(scope, indexes_to_remove): - new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove] + new_selections = [ + selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove + ] if not new_selections: new_selections.append(DEFAULT_SELECTION) scope.expression.set("expressions", new_selections) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index ebee92a..69fe2b8 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -215,13 +215,21 @@ def _qualify_columns(scope, resolver): # Determine whether each reference in the order by clause is to a column or an alias. for ordered in scope.find_all(exp.Ordered): for column in ordered.find_all(exp.Column): - if not column.table and column.parent is not ordered and column.name in resolver.all_columns: + if ( + not column.table + and column.parent is not ordered + and column.name in resolver.all_columns + ): columns_missing_from_scope.append(column) # Determine whether each reference in the having clause is to a column or an alias. for having in scope.find_all(exp.Having): for column in having.find_all(exp.Column): - if not column.table and column.find_ancestor(exp.AggFunc) and column.name in resolver.all_columns: + if ( + not column.table + and column.find_ancestor(exp.AggFunc) + and column.name in resolver.all_columns + ): columns_missing_from_scope.append(column) for column in columns_missing_from_scope: @@ -295,7 +303,9 @@ def _qualify_outputs(scope): """Ensure all output columns are aliased""" new_selections = [] - for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)): + for i, (selection, aliased_column) in enumerate( + itertools.zip_longest(scope.selects, scope.outer_column_list) + ): if isinstance(selection, exp.Column): # convoluted setter because a simple selection.replace(alias) would require a copy alias_ = alias(exp.column(""), alias=selection.name) @@ -343,14 +353,18 @@ class _Resolver: (str) table name """ if self._unambiguous_columns is None: - self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns()) + self._unambiguous_columns = self._get_unambiguous_columns( + self._get_all_source_columns() + ) return self._unambiguous_columns.get(column_name) @property def all_columns(self): """All available columns of all sources in this scope""" if self._all_columns is None: - self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns) + self._all_columns = set( + column for columns in self._get_all_source_columns().values() for column in columns + ) return self._all_columns def get_source_columns(self, name, only_visible=False): @@ -377,7 +391,9 @@ class _Resolver: def _get_all_source_columns(self): if self._source_columns is None: - self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources} + self._source_columns = { + k: self.get_source_columns(k) for k in self.scope.selected_sources + } return self._source_columns def _get_unambiguous_columns(self, source_columns): diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 5a75ee2..18848f3 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -226,7 +226,9 @@ class Scope: self._ensure_collected() columns = self._raw_columns - external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns] + external_columns = [ + column for scope in self.subquery_scopes for column in scope.external_columns + ] named_outputs = {e.alias_or_name for e in self.expression.expressions} @@ -278,7 +280,11 @@ class Scope: Returns: dict[str, Scope]: Mapping of source alias to Scope """ - return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte} + return { + alias: scope + for alias, scope in self.sources.items() + if isinstance(scope, Scope) and scope.is_cte + } @property def selects(self): @@ -307,7 +313,9 @@ class Scope: sources in the current scope. """ if self._external_columns is None: - self._external_columns = [c for c in self.columns if c.table not in self.selected_sources] + self._external_columns = [ + c for c in self.columns if c.table not in self.selected_sources + ] return self._external_columns @property diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index c077906..d759e86 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -229,7 +229,9 @@ def simplify_literals(expression): operands.append(a) if len(operands) < size: - return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands) + return functools.reduce( + lambda a, b: expression.__class__(this=a, expression=b), operands + ) elif isinstance(expression, exp.Neg): this = expression.this if this.is_number: @@ -255,6 +257,12 @@ def _simplify_binary(expression, a, b): return TRUE if not_ else FALSE if a == NULL: return FALSE if not_ else TRUE + elif isinstance(expression, exp.NullSafeEQ): + if a == b: + return TRUE + elif isinstance(expression, exp.NullSafeNEQ): + if a == b: + return FALSE elif NULL in (a, b): return NULL @@ -357,7 +365,7 @@ def extract_date(cast): def extract_interval(interval): try: - from dateutil.relativedelta import relativedelta + from dateutil.relativedelta import relativedelta # type: ignore except ModuleNotFoundError: return None diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 11c6eba..f41a84e 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -89,7 +89,11 @@ def decorrelate(select, parent_select, external_columns, sequence): return if isinstance(predicate, exp.Binary): - key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left + key = ( + predicate.right + if any(node is column for node, *_ in predicate.left.walk()) + else predicate.left + ) else: return @@ -145,7 +149,9 @@ def decorrelate(select, parent_select, external_columns, sequence): else: parent_predicate = _replace(parent_predicate, "TRUE") elif isinstance(parent_predicate, exp.All): - parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})") + parent_predicate = _replace( + parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})" + ) elif isinstance(parent_predicate, exp.Any): if value.this in group_by: parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}") @@ -168,7 +174,9 @@ def decorrelate(select, parent_select, external_columns, sequence): if key in group_by: key.replace(nested) - parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)") + parent_predicate = _replace( + parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)" + ) elif isinstance(predicate, exp.EQ): parent_predicate = _replace( parent_predicate, diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 79a1d90..bbea0e5 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -1,9 +1,13 @@ +from __future__ import annotations + import logging +import typing as t from sqlglot import exp from sqlglot.errors import ErrorLevel, ParseError, concat_errors -from sqlglot.helper import apply_index_offset, ensure_list, list_get +from sqlglot.helper import apply_index_offset, ensure_collection, seq_get from sqlglot.tokens import Token, Tokenizer, TokenType +from sqlglot.trie import in_trie, new_trie logger = logging.getLogger("sqlglot") @@ -20,7 +24,15 @@ def parse_var_map(args): ) -class Parser: +class _Parser(type): + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + klass._show_trie = new_trie(key.split(" ") for key in klass.SHOW_PARSERS) + klass._set_trie = new_trie(key.split(" ") for key in klass.SET_PARSERS) + return klass + + +class Parser(metaclass=_Parser): """ Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer` and produces a parsed syntax tree. @@ -45,16 +57,16 @@ class Parser: FUNCTIONS = { **{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()}, "DATE_TO_DATE_STR": lambda args: exp.Cast( - this=list_get(args, 0), + this=seq_get(args, 0), to=exp.DataType(this=exp.DataType.Type.TEXT), ), "TIME_TO_TIME_STR": lambda args: exp.Cast( - this=list_get(args, 0), + this=seq_get(args, 0), to=exp.DataType(this=exp.DataType.Type.TEXT), ), "TS_OR_DS_TO_DATE_STR": lambda args: exp.Substring( this=exp.Cast( - this=list_get(args, 0), + this=seq_get(args, 0), to=exp.DataType(this=exp.DataType.Type.TEXT), ), start=exp.Literal.number(1), @@ -90,6 +102,7 @@ class Parser: TokenType.NVARCHAR, TokenType.TEXT, TokenType.BINARY, + TokenType.VARBINARY, TokenType.JSON, TokenType.INTERVAL, TokenType.TIMESTAMP, @@ -243,6 +256,7 @@ class Parser: EQUALITY = { TokenType.EQ: exp.EQ, TokenType.NEQ: exp.NEQ, + TokenType.NULLSAFE_EQ: exp.NullSafeEQ, } COMPARISON = { @@ -298,6 +312,21 @@ class Parser: TokenType.ANTI, } + LAMBDAS = { + TokenType.ARROW: lambda self, expressions: self.expression( + exp.Lambda, + this=self._parse_conjunction().transform( + self._replace_lambda, {node.name for node in expressions} + ), + expressions=expressions, + ), + TokenType.FARROW: lambda self, expressions: self.expression( + exp.Kwarg, + this=exp.Var(this=expressions[0].name), + expression=self._parse_conjunction(), + ), + } + COLUMN_OPERATORS = { TokenType.DOT: None, TokenType.DCOLON: lambda self, this, to: self.expression( @@ -362,20 +391,30 @@ class Parser: TokenType.DELETE: lambda self: self._parse_delete(), TokenType.CACHE: lambda self: self._parse_cache(), TokenType.UNCACHE: lambda self: self._parse_uncache(), + TokenType.USE: lambda self: self._parse_use(), } PRIMARY_PARSERS = { - TokenType.STRING: lambda _, token: exp.Literal.string(token.text), - TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text), - TokenType.STAR: lambda self, _: exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}), - TokenType.NULL: lambda *_: exp.Null(), - TokenType.TRUE: lambda *_: exp.Boolean(this=True), - TokenType.FALSE: lambda *_: exp.Boolean(this=False), - TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()), - TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text), - TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text), - TokenType.BYTE_STRING: lambda _, token: exp.ByteString(this=token.text), + TokenType.STRING: lambda self, token: self.expression( + exp.Literal, this=token.text, is_string=True + ), + TokenType.NUMBER: lambda self, token: self.expression( + exp.Literal, this=token.text, is_string=False + ), + TokenType.STAR: lambda self, _: self.expression( + exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()} + ), + TokenType.NULL: lambda self, _: self.expression(exp.Null), + TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True), + TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False), + TokenType.PARAMETER: lambda self, _: self.expression( + exp.Parameter, this=self._parse_var() or self._parse_primary() + ), + TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text), + TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text), + TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text), TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token), + TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(), } RANGE_PARSERS = { @@ -411,16 +450,24 @@ class Parser: TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty), TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty), - TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(exp.TableFormatProperty), + TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment( + exp.TableFormatProperty + ), TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty), TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty), TokenType.EXECUTE: lambda self: self._parse_execute_as(), TokenType.DETERMINISTIC: lambda self: self.expression( exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE") ), - TokenType.IMMUTABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")), - TokenType.STABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("STABLE")), - TokenType.VOLATILE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")), + TokenType.IMMUTABLE: lambda self: self.expression( + exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE") + ), + TokenType.STABLE: lambda self: self.expression( + exp.VolatilityProperty, this=exp.Literal.string("STABLE") + ), + TokenType.VOLATILE: lambda self: self.expression( + exp.VolatilityProperty, this=exp.Literal.string("VOLATILE") + ), } CONSTRAINT_PARSERS = { @@ -450,7 +497,8 @@ class Parser: "group": lambda self: self._parse_group(), "having": lambda self: self._parse_having(), "qualify": lambda self: self._parse_qualify(), - "window": lambda self: self._match(TokenType.WINDOW) and self._parse_window(self._parse_id_var(), alias=True), + "window": lambda self: self._match(TokenType.WINDOW) + and self._parse_window(self._parse_id_var(), alias=True), "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute), "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort), "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster), @@ -459,6 +507,9 @@ class Parser: "offset": lambda self: self._parse_offset(), } + SHOW_PARSERS: t.Dict[str, t.Callable] = {} + SET_PARSERS: t.Dict[str, t.Callable] = {} + MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) CREATABLES = { @@ -488,7 +539,9 @@ class Parser: "_curr", "_next", "_prev", - "_greedy_subqueries", + "_prev_comment", + "_show_trie", + "_set_trie", ) def __init__( @@ -519,7 +572,7 @@ class Parser: self._curr = None self._next = None self._prev = None - self._greedy_subqueries = False + self._prev_comment = None def parse(self, raw_tokens, sql=None): """ @@ -533,10 +586,12 @@ class Parser: Returns the list of syntax trees (:class:`~sqlglot.expressions.Expression`). """ - return self._parse(parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql) + return self._parse( + parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql + ) def parse_into(self, expression_types, raw_tokens, sql=None): - for expression_type in ensure_list(expression_types): + for expression_type in ensure_collection(expression_types): parser = self.EXPRESSION_PARSERS.get(expression_type) if not parser: raise TypeError(f"No parser registered for {expression_type}") @@ -597,6 +652,9 @@ class Parser: def expression(self, exp_class, **kwargs): instance = exp_class(**kwargs) + if self._prev_comment: + instance.comment = self._prev_comment + self._prev_comment = None self.validate_expression(instance) return instance @@ -633,14 +691,16 @@ class Parser: return index - def _get_token(self, index): - return list_get(self._tokens, index) - def _advance(self, times=1): self._index += times - self._curr = self._get_token(self._index) - self._next = self._get_token(self._index + 1) - self._prev = self._get_token(self._index - 1) if self._index > 0 else None + self._curr = seq_get(self._tokens, self._index) + self._next = seq_get(self._tokens, self._index + 1) + if self._index > 0: + self._prev = self._tokens[self._index - 1] + self._prev_comment = self._prev.comment + else: + self._prev = None + self._prev_comment = None def _retreat(self, index): self._advance(index - self._index) @@ -661,6 +721,7 @@ class Parser: expression = self._parse_expression() expression = self._parse_set_operations(expression) if expression else self._parse_select() + self._parse_query_modifiers(expression) return expression @@ -682,7 +743,11 @@ class Parser: ) def _parse_exists(self, not_=False): - return self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS) + return ( + self._match(TokenType.IF) + and (not not_ or self._match(TokenType.NOT)) + and self._match(TokenType.EXISTS) + ) def _parse_create(self): replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE) @@ -931,7 +996,9 @@ class Parser: return self.expression( exp.Delete, this=self._parse_table(schema=True), - using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table(schema=True)), + using=self._parse_csv( + lambda: self._match(TokenType.USING) and self._parse_table(schema=True) + ), where=self._parse_where(), ) @@ -983,11 +1050,13 @@ class Parser: return None def parse_values(): - k = self._parse_var() + key = self._parse_var() + value = None + if self._match(TokenType.EQ): - v = self._parse_string() - return (k, v) - return (k, None) + value = self._parse_string() + + return exp.Property(this=key, value=value) self._match_l_paren() values = self._parse_csv(parse_values) @@ -1019,6 +1088,8 @@ class Parser: self.raise_error(f"{this.key} does not support CTE") this = cte elif self._match(TokenType.SELECT): + comment = self._prev_comment + hint = self._parse_hint() all_ = self._match(TokenType.ALL) distinct = self._match(TokenType.DISTINCT) @@ -1033,7 +1104,7 @@ class Parser: self.raise_error("Cannot specify both ALL and DISTINCT after SELECT") limit = self._parse_limit(top=True) - expressions = self._parse_csv(lambda: self._parse_annotation(self._parse_expression())) + expressions = self._parse_csv(self._parse_expression) this = self.expression( exp.Select, @@ -1042,6 +1113,7 @@ class Parser: expressions=expressions, limit=limit, ) + this.comment = comment from_ = self._parse_from() if from_: this.set("from", from_) @@ -1072,8 +1144,10 @@ class Parser: while True: expressions.append(self._parse_cte()) - if not self._match(TokenType.COMMA): + if not self._match(TokenType.COMMA) and not self._match(TokenType.WITH): break + else: + self._match(TokenType.WITH) return self.expression( exp.With, @@ -1111,11 +1185,7 @@ class Parser: if not alias and not columns: return None - return self.expression( - exp.TableAlias, - this=alias, - columns=columns, - ) + return self.expression(exp.TableAlias, this=alias, columns=columns) def _parse_subquery(self, this): return self.expression( @@ -1150,12 +1220,6 @@ class Parser: if expression: this.set(key, expression) - def _parse_annotation(self, expression): - if self._match(TokenType.ANNOTATION): - return self.expression(exp.Annotation, this=self._prev.text.strip(), expression=expression) - - return expression - def _parse_hint(self): if self._match(TokenType.HINT): hints = self._parse_csv(self._parse_function) @@ -1295,7 +1359,9 @@ class Parser: if not table: self.raise_error("Expected table name") - this = self.expression(exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()) + this = self.expression( + exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots() + ) if schema: return self._parse_schema(this=this) @@ -1500,7 +1566,9 @@ class Parser: if not skip_order_token and not self._match(TokenType.ORDER_BY): return this - return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)) + return self.expression( + exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered) + ) def _parse_sort(self, token_type, exp_class): if not self._match(token_type): @@ -1521,7 +1589,8 @@ class Parser: if ( not explicitly_null_ordered and ( - (asc and self.null_ordering == "nulls_are_small") or (desc and self.null_ordering != "nulls_are_small") + (asc and self.null_ordering == "nulls_are_small") + or (desc and self.null_ordering != "nulls_are_small") ) and self.null_ordering != "nulls_are_last" ): @@ -1606,6 +1675,9 @@ class Parser: def _parse_is(self, this): negate = self._match(TokenType.NOT) + if self._match(TokenType.DISTINCT_FROM): + klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ + return self.expression(klass, this=this, expression=self._parse_expression()) this = self.expression( exp.Is, this=this, @@ -1653,9 +1725,13 @@ class Parser: expression=self._parse_term(), ) elif self._match_pair(TokenType.LT, TokenType.LT): - this = self.expression(exp.BitwiseLeftShift, this=this, expression=self._parse_term()) + this = self.expression( + exp.BitwiseLeftShift, this=this, expression=self._parse_term() + ) elif self._match_pair(TokenType.GT, TokenType.GT): - this = self.expression(exp.BitwiseRightShift, this=this, expression=self._parse_term()) + this = self.expression( + exp.BitwiseRightShift, this=this, expression=self._parse_term() + ) else: break @@ -1685,7 +1761,7 @@ class Parser: ) index = self._index - type_token = self._parse_types() + type_token = self._parse_types(check_func=True) this = self._parse_column() if type_token: @@ -1698,7 +1774,7 @@ class Parser: return this - def _parse_types(self): + def _parse_types(self, check_func=False): index = self._index if not self._match_set(self.TYPE_TOKENS): @@ -1708,10 +1784,13 @@ class Parser: nested = type_token in self.NESTED_TYPE_TOKENS is_struct = type_token == TokenType.STRUCT expressions = None + maybe_func = False if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): return exp.DataType( - this=exp.DataType.Type.ARRAY, expressions=[exp.DataType.build(type_token.value)], nested=True + this=exp.DataType.Type.ARRAY, + expressions=[exp.DataType.build(type_token.value)], + nested=True, ) if self._match(TokenType.L_BRACKET): @@ -1731,6 +1810,7 @@ class Parser: return None self._match_r_paren() + maybe_func = True if nested and self._match(TokenType.LT): if is_struct: @@ -1741,25 +1821,46 @@ class Parser: if not self._match(TokenType.GT): self.raise_error("Expecting >") + value = None if type_token in self.TIMESTAMPS: - tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ - if tz: - return exp.DataType( + if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ: + value = exp.DataType( this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions, ) - ltz = self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ - if ltz: - return exp.DataType( + elif ( + self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ + ): + value = exp.DataType( this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions, ) - self._match(TokenType.WITHOUT_TIME_ZONE) + elif self._match(TokenType.WITHOUT_TIME_ZONE): + value = exp.DataType( + this=exp.DataType.Type.TIMESTAMP, + expressions=expressions, + ) - return exp.DataType( - this=exp.DataType.Type.TIMESTAMP, - expressions=expressions, - ) + maybe_func = maybe_func and value is None + + if value is None: + value = exp.DataType( + this=exp.DataType.Type.TIMESTAMP, + expressions=expressions, + ) + + if maybe_func and check_func: + index2 = self._index + peek = self._parse_string() + + if not peek: + self._retreat(index) + return None + + self._retreat(index2) + + if value: + return value return exp.DataType( this=exp.DataType.Type[type_token.value.upper()], @@ -1826,22 +1927,29 @@ class Parser: return exp.Literal.number(f"0.{self._prev.text}") if self._match(TokenType.L_PAREN): + comment = self._prev_comment query = self._parse_select() if query: expressions = [query] else: - expressions = self._parse_csv(lambda: self._parse_alias(self._parse_conjunction(), explicit=True)) + expressions = self._parse_csv( + lambda: self._parse_alias(self._parse_conjunction(), explicit=True) + ) - this = list_get(expressions, 0) + this = seq_get(expressions, 0) self._parse_query_modifiers(this) self._match_r_paren() if isinstance(this, exp.Subqueryable): - return self._parse_set_operations(self._parse_subquery(this)) - if len(expressions) > 1: - return self.expression(exp.Tuple, expressions=expressions) - return self.expression(exp.Paren, this=this) + this = self._parse_set_operations(self._parse_subquery(this)) + elif len(expressions) > 1: + this = self.expression(exp.Tuple, expressions=expressions) + else: + this = self.expression(exp.Paren, this=this) + if comment: + this.comment = comment + return this return None @@ -1894,7 +2002,8 @@ class Parser: self.validate_expression(this, args) else: this = self.expression(exp.Anonymous, this=this, expressions=args) - self._match_r_paren() + + self._match_r_paren(this) return self._parse_window(this) def _parse_user_defined_function(self): @@ -1920,6 +2029,18 @@ class Parser: return self.expression(exp.Identifier, this=token.text) + def _parse_session_parameter(self): + kind = None + this = self._parse_id_var() or self._parse_primary() + if self._match(TokenType.DOT): + kind = this.name + this = self._parse_var() or self._parse_primary() + return self.expression( + exp.SessionParameter, + this=this, + kind=kind, + ) + def _parse_udf_kwarg(self): this = self._parse_id_var() kind = self._parse_types() @@ -1938,27 +2059,24 @@ class Parser: else: expressions = [self._parse_id_var()] - if not self._match(TokenType.ARROW): - self._retreat(index) + if self._match_set(self.LAMBDAS): + return self.LAMBDAS[self._prev.token_type](self, expressions) - if self._match(TokenType.DISTINCT): - this = self.expression(exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)) - else: - this = self._parse_conjunction() + self._retreat(index) - if self._match(TokenType.IGNORE_NULLS): - this = self.expression(exp.IgnoreNulls, this=this) - else: - self._match(TokenType.RESPECT_NULLS) + if self._match(TokenType.DISTINCT): + this = self.expression( + exp.Distinct, expressions=self._parse_csv(self._parse_conjunction) + ) + else: + this = self._parse_conjunction() - return self._parse_alias(self._parse_limit(self._parse_order(this))) + if self._match(TokenType.IGNORE_NULLS): + this = self.expression(exp.IgnoreNulls, this=this) + else: + self._match(TokenType.RESPECT_NULLS) - conjunction = self._parse_conjunction().transform(self._replace_lambda, {node.name for node in expressions}) - return self.expression( - exp.Lambda, - this=conjunction, - expressions=expressions, - ) + return self._parse_alias(self._parse_limit(self._parse_order(this))) def _parse_schema(self, this=None): index = self._index @@ -1966,7 +2084,9 @@ class Parser: self._retreat(index) return this - args = self._parse_csv(lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True))) + args = self._parse_csv( + lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True)) + ) self._match_r_paren() return self.expression(exp.Schema, this=this, expressions=args) @@ -2104,6 +2224,7 @@ class Parser: if not self._match(TokenType.R_BRACKET): self.raise_error("Expected ]") + this.comment = self._prev_comment return self._parse_bracket(this) def _parse_case(self): @@ -2124,7 +2245,9 @@ class Parser: if not self._match(TokenType.END): self.raise_error("Expected END after CASE", self._prev) - return self._parse_window(self.expression(exp.Case, this=expression, ifs=ifs, default=default)) + return self._parse_window( + self.expression(exp.Case, this=expression, ifs=ifs, default=default) + ) def _parse_if(self): if self._match(TokenType.L_PAREN): @@ -2331,7 +2454,9 @@ class Parser: self._match(TokenType.BETWEEN) return { - "value": (self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text) + "value": ( + self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text + ) or self._parse_bitwise(), "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text, } @@ -2348,7 +2473,7 @@ class Parser: this=this, expressions=self._parse_csv(lambda: self._parse_id_var(any_token)), ) - self._match_r_paren() + self._match_r_paren(aliases) return aliases alias = self._parse_id_var(any_token) @@ -2365,28 +2490,29 @@ class Parser: return identifier if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: - return self._advance() or exp.Identifier(this=self._prev.text, quoted=False) - - return self._match_set(tokens or self.ID_VAR_TOKENS) and exp.Identifier(this=self._prev.text, quoted=False) + self._advance() + elif not self._match_set(tokens or self.ID_VAR_TOKENS): + return None + return exp.Identifier(this=self._prev.text, quoted=False) def _parse_string(self): if self._match(TokenType.STRING): - return exp.Literal.string(self._prev.text) + return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev) return self._parse_placeholder() def _parse_number(self): if self._match(TokenType.NUMBER): - return exp.Literal.number(self._prev.text) + return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev) return self._parse_placeholder() def _parse_identifier(self): if self._match(TokenType.IDENTIFIER): - return exp.Identifier(this=self._prev.text, quoted=True) + return self.expression(exp.Identifier, this=self._prev.text, quoted=True) return self._parse_placeholder() def _parse_var(self): if self._match(TokenType.VAR): - return exp.Var(this=self._prev.text) + return self.expression(exp.Var, this=self._prev.text) return self._parse_placeholder() def _parse_var_or_string(self): @@ -2394,27 +2520,27 @@ class Parser: def _parse_null(self): if self._match(TokenType.NULL): - return exp.Null() + return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev) return None def _parse_boolean(self): if self._match(TokenType.TRUE): - return exp.Boolean(this=True) + return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev) if self._match(TokenType.FALSE): - return exp.Boolean(this=False) + return self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev) return None def _parse_star(self): if self._match(TokenType.STAR): - return exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}) + return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev) return None def _parse_placeholder(self): if self._match(TokenType.PLACEHOLDER): - return exp.Placeholder() + return self.expression(exp.Placeholder) elif self._match(TokenType.COLON): self._advance() - return exp.Placeholder(this=self._prev.text) + return self.expression(exp.Placeholder, this=self._prev.text) return None def _parse_except(self): @@ -2432,22 +2558,27 @@ class Parser: self._match_r_paren() return columns - def _parse_csv(self, parse): - parse_result = parse() + def _parse_csv(self, parse_method): + parse_result = parse_method() items = [parse_result] if parse_result is not None else [] while self._match(TokenType.COMMA): - parse_result = parse() + if parse_result and self._prev_comment is not None: + parse_result.comment = self._prev_comment + + parse_result = parse_method() if parse_result is not None: items.append(parse_result) return items - def _parse_tokens(self, parse, expressions): - this = parse() + def _parse_tokens(self, parse_method, expressions): + this = parse_method() while self._match_set(expressions): - this = self.expression(expressions[self._prev.token_type], this=this, expression=parse()) + this = self.expression( + expressions[self._prev.token_type], this=this, expression=parse_method() + ) return this @@ -2460,6 +2591,47 @@ class Parser: def _parse_select_or_expression(self): return self._parse_select() or self._parse_expression() + def _parse_use(self): + return self.expression(exp.Use, this=self._parse_id_var()) + + def _parse_show(self): + parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) + if parser: + return parser(self) + self._advance() + return self.expression(exp.Show, this=self._prev.text.upper()) + + def _default_parse_set_item(self): + return self.expression( + exp.SetItem, + this=self._parse_statement(), + ) + + def _parse_set_item(self): + parser = self._find_parser(self.SET_PARSERS, self._set_trie) + return parser(self) if parser else self._default_parse_set_item() + + def _parse_set(self): + return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item)) + + def _find_parser(self, parsers, trie): + index = self._index + this = [] + while True: + # The current token might be multiple words + curr = self._curr.text.upper() + key = curr.split(" ") + this.append(curr) + self._advance() + result, trie = in_trie(trie, key) + if result == 0: + break + if result == 2: + subparser = parsers[" ".join(this)] + return subparser + self._retreat(index) + return None + def _match(self, token_type): if not self._curr: return None @@ -2491,13 +2663,17 @@ class Parser: return None - def _match_l_paren(self): + def _match_l_paren(self, expression=None): if not self._match(TokenType.L_PAREN): self.raise_error("Expecting (") + if expression and self._prev_comment: + expression.comment = self._prev_comment - def _match_r_paren(self): + def _match_r_paren(self, expression=None): if not self._match(TokenType.R_PAREN): self.raise_error("Expecting )") + if expression and self._prev_comment: + expression.comment = self._prev_comment def _match_text(self, *texts): index = self._index diff --git a/sqlglot/planner.py b/sqlglot/planner.py index ea995d8..cd1de5e 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -72,7 +72,9 @@ class Step: if from_: from_ = from_.expressions if len(from_) > 1: - raise UnsupportedError("Multi-from statements are unsupported. Run it through the optimizer") + raise UnsupportedError( + "Multi-from statements are unsupported. Run it through the optimizer" + ) step = Scan.from_expression(from_[0], ctes) else: @@ -102,7 +104,7 @@ class Step: continue if operand not in operands: operands[operand] = f"_a_{next(sequence)}" - operand.replace(exp.column(operands[operand], step.name, quoted=True)) + operand.replace(exp.column(operands[operand], quoted=True)) else: projections.append(e) @@ -117,9 +119,11 @@ class Step: aggregate = Aggregate() aggregate.source = step.name aggregate.name = step.name - aggregate.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items()) + aggregate.operands = tuple( + alias(operand, alias_) for operand, alias_ in operands.items() + ) aggregate.aggregations = aggregations - aggregate.group = [exp.column(e.alias_or_name, step.name, quoted=True) for e in group.expressions] + aggregate.group = group.expressions aggregate.add_dependency(step) step = aggregate @@ -136,9 +140,6 @@ class Step: sort.key = order.expressions sort.add_dependency(step) step = sort - for k in sort.key + projections: - for column in k.find_all(exp.Column): - column.set("table", exp.to_identifier(step.name, quoted=True)) step.projections = projections @@ -203,7 +204,9 @@ class Scan(Step): alias_ = expression.alias if not alias_: - raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer") + raise UnsupportedError( + "Tables/Subqueries must be aliased. Run it through the optimizer" + ) if isinstance(expression, exp.Subquery): table = expression.this diff --git a/sqlglot/py.typed b/sqlglot/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/sqlglot/py.typed diff --git a/sqlglot/schema.py b/sqlglot/schema.py index c916330..fcf7291 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -1,44 +1,60 @@ +from __future__ import annotations + import abc +import typing as t from sqlglot import expressions as exp -from sqlglot.errors import OptimizeError +from sqlglot.errors import SchemaError from sqlglot.helper import csv_reader +from sqlglot.trie import in_trie, new_trie + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql.types import StructType + + ColumnMapping = t.Union[t.Dict, str, StructType, t.List] + +TABLE_ARGS = ("this", "db", "catalog") class Schema(abc.ABC): """Abstract base class for database schemas""" @abc.abstractmethod - def add_table(self, table, column_mapping=None): + def add_table( + self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None + ) -> None: """ - Register or update a table. Some implementing classes may require column information to also be provided + Register or update a table. Some implementing classes may require column information to also be provided. Args: - table (sqlglot.expressions.Table|str): Table expression instance or string representing the table - column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table + table: table expression instance or string representing the table. + column_mapping: a column mapping that describes the structure of the table. """ @abc.abstractmethod - def column_names(self, table, only_visible=False): + def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]: """ Get the column names for a table. + Args: - table (sqlglot.expressions.Table): Table expression instance - only_visible (bool): Whether to include invisible columns + table: the `Table` expression instance. + only_visible: whether to include invisible columns. + Returns: - list[str]: list of column names + The list of column names. """ @abc.abstractmethod - def get_column_type(self, table, column): + def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType.Type: """ - Get the exp.DataType type of a column in the schema. + Get the :class:`sqlglot.exp.DataType` type of a column in the schema. Args: - table (sqlglot.expressions.Table): The source table. - column (sqlglot.expressions.Column): The target column. + table: the source table. + column: the target column. + Returns: - sqlglot.expressions.DataType.Type: The resulting column type. + The resulting column type. """ @@ -60,132 +76,179 @@ class MappingSchema(Schema): dialect (str): The dialect to be used for custom type mappings. """ - def __init__(self, schema=None, visible=None, dialect=None): + def __init__( + self, + schema: t.Optional[t.Dict] = None, + visible: t.Optional[t.Dict] = None, + dialect: t.Optional[str] = None, + ) -> None: self.schema = schema or {} - self.visible = visible + self.visible = visible or {} + self.schema_trie = self._build_trie(self.schema) self.dialect = dialect - self._type_mapping_cache = {} - self.supported_table_args = [] - self.forbidden_table_args = set() - if self.schema: - self._initialize_supported_args() + self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {} + self._supported_table_args: t.Tuple[str, ...] = tuple() @classmethod - def from_mapping_schema(cls, mapping_schema): + def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: + return MappingSchema( + schema=mapping_schema.schema, + visible=mapping_schema.visible, + dialect=mapping_schema.dialect, + ) + + def copy(self, **kwargs) -> MappingSchema: return MappingSchema( - schema=mapping_schema.schema, visible=mapping_schema.visible, dialect=mapping_schema.dialect + **{ # type: ignore + "schema": self.schema.copy(), + "visible": self.visible.copy(), + "dialect": self.dialect, + **kwargs, + } ) - def copy(self, **kwargs): - return MappingSchema(**{"schema": self.schema.copy(), **kwargs}) + @property + def supported_table_args(self): + if not self._supported_table_args and self.schema: + depth = _dict_depth(self.schema) - def add_table(self, table, column_mapping=None): + if not depth or depth == 1: # {} + self._supported_table_args = tuple() + elif 2 <= depth <= 4: + self._supported_table_args = TABLE_ARGS[: depth - 1] + else: + raise SchemaError(f"Invalid schema shape. Depth: {depth}") + + return self._supported_table_args + + def add_table( + self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None + ) -> None: """ Register or update a table. Updates are only performed if a new column mapping is provided. Args: - table (sqlglot.expressions.Table|str): Table expression instance or string representing the table - column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table + table: the `Table` expression instance or string representing the table. + column_mapping: a column mapping that describes the structure of the table. """ - table = exp.to_table(table) - self._validate_table(table) + table_ = self._ensure_table(table) column_mapping = ensure_column_mapping(column_mapping) - table_args = [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)] - existing_column_mapping = _nested_get( - self.schema, *zip(self.supported_table_args, table_args), raise_on_missing=False - ) - if existing_column_mapping and not column_mapping: + schema = self.find_schema(table_, raise_on_missing=False) + + if schema and not column_mapping: return + _nested_set( self.schema, - [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)], + list(reversed(self.table_parts(table_))), column_mapping, ) - self._initialize_supported_args() + self.schema_trie = self._build_trie(self.schema) - def _get_table_args_from_table(self, table): - if table.args.get("catalog") is not None: - return "catalog", "db", "this" - if table.args.get("db") is not None: - return "db", "this" - return ("this",) + def _ensure_table(self, table: exp.Table | str) -> exp.Table: + table_ = exp.to_table(table) - def _validate_table(self, table): - if not self.supported_table_args and isinstance(table, exp.Table): - return - for forbidden in self.forbidden_table_args: - if table.text(forbidden): - raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") - for expected in self.supported_table_args: - if not table.text(expected): - raise ValueError(f"Table is expected to have {expected}. Received: {table.sql()} ") + if not table_: + raise SchemaError(f"Not a valid table '{table}'") + + return table_ + + def table_parts(self, table: exp.Table) -> t.List[str]: + return [table.text(part) for part in TABLE_ARGS if table.text(part)] + + def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]: + table_ = self._ensure_table(table) - def column_names(self, table, only_visible=False): - table = exp.to_table(table) - if not isinstance(table.this, exp.Identifier): - return fs_get(table) + if not isinstance(table_.this, exp.Identifier): + return fs_get(table) # type: ignore - args = tuple(table.text(p) for p in self.supported_table_args) + schema = self.find_schema(table_) - for forbidden in self.forbidden_table_args: - if table.text(forbidden): - raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") + if schema is None: + raise SchemaError(f"Could not find table schema {table}") - columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args))) if not only_visible or not self.visible: - return columns + return list(schema) - visible = _nested_get(self.visible, *zip(self.supported_table_args, args)) - return [col for col in columns if col in visible] + visible = self._nested_get(self.table_parts(table_), self.visible) + return [col for col in schema if col in visible] # type: ignore - def get_column_type(self, table, column): - try: - schema_type = self.schema.get(table.name, {}).get(column.name).upper() + def find_schema( + self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True + ) -> t.Optional[t.Dict[str, str]]: + parts = self.table_parts(table)[0 : len(self.supported_table_args)] + value, trie = in_trie(self.schema_trie if trie is None else trie, parts) + + if value == 0: + if raise_on_missing: + raise SchemaError(f"Cannot find schema for {table}.") + else: + return None + elif value == 1: + possibilities = flatten_schema(trie) + if len(possibilities) == 1: + parts.extend(possibilities[0]) + else: + message = ", ".join(".".join(parts) for parts in possibilities) + if raise_on_missing: + raise SchemaError(f"Ambiguous schema for {table}: {message}.") + return None + + return self._nested_get(parts, raise_on_missing=raise_on_missing) + + def get_column_type( + self, table: exp.Table | str, column: exp.Column | str + ) -> exp.DataType.Type: + column_name = column if isinstance(column, str) else column.name + table_ = exp.to_table(table) + if table_: + table_schema = self.find_schema(table_) + schema_type = table_schema.get(column_name).upper() # type: ignore return self._convert_type(schema_type) - except: - raise OptimizeError(f"Failed to get type for column {column.sql()}") + raise SchemaError(f"Could not convert table '{table}'") - def _convert_type(self, schema_type): + def _convert_type(self, schema_type: str) -> exp.DataType.Type: """ - Convert a type represented as a string to the corresponding exp.DataType.Type object. + Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object. + Args: - schema_type (str): The type we want to convert. + schema_type: the type we want to convert. + Returns: - sqlglot.expressions.DataType.Type: The resulting expression type. + The resulting expression type. """ if schema_type not in self._type_mapping_cache: try: - self._type_mapping_cache[schema_type] = exp.maybe_parse( - schema_type, into=exp.DataType, dialect=self.dialect - ).this + expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect) + if expression is None: + raise ValueError(f"Could not parse {schema_type}") + self._type_mapping_cache[schema_type] = expression.this except AttributeError: - raise OptimizeError(f"Failed to convert type {schema_type}") + raise SchemaError(f"Failed to convert type {schema_type}") return self._type_mapping_cache[schema_type] - def _initialize_supported_args(self): - if not self.supported_table_args: - depth = _dict_depth(self.schema) - - all_args = ["this", "db", "catalog"] - if not depth or depth == 1: # {} - self.supported_table_args = [] - elif 2 <= depth <= 4: - self.supported_table_args = tuple(reversed(all_args[: depth - 1])) - else: - raise OptimizeError(f"Invalid schema shape. Depth: {depth}") + def _build_trie(self, schema: t.Dict): + return new_trie(tuple(reversed(t)) for t in flatten_schema(schema)) - self.forbidden_table_args = {"catalog", "db", "this"} - set(self.supported_table_args) + def _nested_get( + self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True + ) -> t.Optional[t.Any]: + return _nested_get( + d or self.schema, + *zip(self.supported_table_args, reversed(parts)), + raise_on_missing=raise_on_missing, + ) -def ensure_schema(schema): +def ensure_schema(schema: t.Any) -> Schema: if isinstance(schema, Schema): return schema return MappingSchema(schema) -def ensure_column_mapping(mapping): +def ensure_column_mapping(mapping: t.Optional[ColumnMapping]): if isinstance(mapping, dict): return mapping elif isinstance(mapping, str): @@ -196,7 +259,7 @@ def ensure_column_mapping(mapping): } # Check if mapping looks like a DataFrame StructType elif hasattr(mapping, "simpleString"): - return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} + return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} # type: ignore elif isinstance(mapping, list): return {x.strip(): None for x in mapping} elif mapping is None: @@ -204,7 +267,20 @@ def ensure_column_mapping(mapping): raise ValueError(f"Invalid mapping provided: {type(mapping)}") -def fs_get(table): +def flatten_schema(schema: t.Dict, keys: t.Optional[t.List[str]] = None) -> t.List[t.List[str]]: + tables = [] + keys = keys or [] + depth = _dict_depth(schema) + + for k, v in schema.items(): + if depth >= 3: + tables.extend(flatten_schema(v, keys + [k])) + elif depth == 2: + tables.append(keys + [k]) + return tables + + +def fs_get(table: exp.Table) -> t.List[str]: name = table.this.name if name.upper() == "READ_CSV": @@ -214,21 +290,23 @@ def fs_get(table): raise ValueError(f"Cannot read schema for {table}") -def _nested_get(d, *path, raise_on_missing=True): +def _nested_get( + d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True +) -> t.Optional[t.Any]: """ Get a value for a nested dictionary. Args: - d (dict): dictionary - *path (tuple[str, str]): tuples of (name, key) + d: the dictionary to search. + *path: tuples of (name, key), where: `key` is the key in the dictionary to get. `name` is a string to use in the error if `key` isn't found. Returns: - The value or None if it doesn't exist + The value or None if it doesn't exist. """ for name, key in path: - d = d.get(key) + d = d.get(key) # type: ignore if d is None: if raise_on_missing: name = "table" if name == "this" else name @@ -237,36 +315,44 @@ def _nested_get(d, *path, raise_on_missing=True): return d -def _nested_set(d, keys, value): +def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict: """ In-place set a value for a nested dictionary - Ex: + Example: >>> _nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}} + >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} - d (dict): dictionary - keys (Iterable[str]): ordered iterable of keys that makeup path to value - value (Any): The value to set in the dictionary for the given key path + Args: + d: dictionary to update. + keys: the keys that makeup the path to `value`. + value: the value to set in the dictionary for the given key path. + + Returns: + The (possibly) updated dictionary. """ if not keys: - return + return d + if len(keys) == 1: d[keys[0]] = value - return + return d + subd = d for key in keys[:-1]: if key not in subd: subd = subd.setdefault(key, {}) else: subd = subd[key] + subd[keys[-1]] = value return d -def _dict_depth(d): +def _dict_depth(d: t.Dict) -> int: """ Get the nesting depth of a dictionary. diff --git a/sqlglot/time.py b/sqlglot/time.py index 729b50d..97726b3 100644 --- a/sqlglot/time.py +++ b/sqlglot/time.py @@ -1,9 +1,13 @@ -# the generic time format is based on python time.strftime +import typing as t + +# The generic time format is based on python time.strftime. # https://docs.python.org/3/library/time.html#time.strftime from sqlglot.trie import in_trie, new_trie -def format_time(string, mapping, trie=None): +def format_time( + string: str, mapping: t.Dict[str, str], trie: t.Optional[t.Dict] = None +) -> t.Optional[str]: """ Converts a time string given a mapping. @@ -11,11 +15,16 @@ def format_time(string, mapping, trie=None): >>> format_time("%Y", {"%Y": "YYYY"}) 'YYYY' - mapping: Dictionary of time format to target time format - trie: Optional trie, can be passed in for performance + Args: + mapping: dictionary of time format to target time format. + trie: optional trie, can be passed in for performance. + + Returns: + The converted time string. """ if not string: return None + start = 0 end = 1 size = len(string) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 766c01a..95d84d6 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import typing as t from enum import auto from sqlglot.helper import AutoName @@ -27,6 +30,7 @@ class TokenType(AutoName): NOT = auto() EQ = auto() NEQ = auto() + NULLSAFE_EQ = auto() AND = auto() OR = auto() AMP = auto() @@ -36,12 +40,14 @@ class TokenType(AutoName): TILDA = auto() ARROW = auto() DARROW = auto() + FARROW = auto() + HASH = auto() HASH_ARROW = auto() DHASH_ARROW = auto() LR_ARROW = auto() - ANNOTATION = auto() DOLLAR = auto() PARAMETER = auto() + SESSION_PARAMETER = auto() SPACE = auto() BREAK = auto() @@ -73,7 +79,7 @@ class TokenType(AutoName): NVARCHAR = auto() TEXT = auto() BINARY = auto() - BYTEA = auto() + VARBINARY = auto() JSON = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() @@ -142,6 +148,7 @@ class TokenType(AutoName): DESCRIBE = auto() DETERMINISTIC = auto() DISTINCT = auto() + DISTINCT_FROM = auto() DISTRIBUTE_BY = auto() DIV = auto() DROP = auto() @@ -238,6 +245,7 @@ class TokenType(AutoName): RETURNS = auto() RIGHT = auto() RLIKE = auto() + ROLLBACK = auto() ROLLUP = auto() ROW = auto() ROWS = auto() @@ -287,37 +295,49 @@ class TokenType(AutoName): class Token: - __slots__ = ("token_type", "text", "line", "col") + __slots__ = ("token_type", "text", "line", "col", "comment") @classmethod - def number(cls, number): + def number(cls, number: int) -> Token: + """Returns a NUMBER token with `number` as its text.""" return cls(TokenType.NUMBER, str(number)) @classmethod - def string(cls, string): + def string(cls, string: str) -> Token: + """Returns a STRING token with `string` as its text.""" return cls(TokenType.STRING, string) @classmethod - def identifier(cls, identifier): + def identifier(cls, identifier: str) -> Token: + """Returns an IDENTIFIER token with `identifier` as its text.""" return cls(TokenType.IDENTIFIER, identifier) @classmethod - def var(cls, var): + def var(cls, var: str) -> Token: + """Returns an VAR token with `var` as its text.""" return cls(TokenType.VAR, var) - def __init__(self, token_type, text, line=1, col=1): + def __init__( + self, + token_type: TokenType, + text: str, + line: int = 1, + col: int = 1, + comment: t.Optional[str] = None, + ) -> None: self.token_type = token_type self.text = text self.line = line self.col = max(col - len(text), 1) + self.comment = comment - def __repr__(self): + def __repr__(self) -> str: attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__) return f"<Token {attributes}>" class _Tokenizer(type): - def __new__(cls, clsname, bases, attrs): + def __new__(cls, clsname, bases, attrs): # type: ignore klass = super().__new__(cls, clsname, bases, attrs) klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES) @@ -325,27 +345,29 @@ class _Tokenizer(type): klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS) klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS) klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS) + klass._ESCAPES = set(klass.ESCAPES) klass._COMMENTS = dict( - (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS + (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) + for comment in klass.COMMENTS ) klass.KEYWORD_TRIE = new_trie( key.upper() - for key, value in { + for key in { **klass.KEYWORDS, **{comment: TokenType.COMMENT for comment in klass._COMMENTS}, **{quote: TokenType.QUOTE for quote in klass._QUOTES}, **{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS}, **{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS}, **{byte_string: TokenType.BYTE_STRING for byte_string in klass._BYTE_STRINGS}, - }.items() + } if " " in key or any(single in key for single in klass.SINGLE_TOKENS) ) return klass @staticmethod - def _delimeter_list_to_dict(list): + def _delimeter_list_to_dict(list: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]: return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list) @@ -375,26 +397,26 @@ class Tokenizer(metaclass=_Tokenizer): "*": TokenType.STAR, "~": TokenType.TILDA, "?": TokenType.PLACEHOLDER, - "#": TokenType.ANNOTATION, "@": TokenType.PARAMETER, # used for breaking a var like x'y' but nothing else # the token type doesn't matter "'": TokenType.QUOTE, "`": TokenType.IDENTIFIER, '"': TokenType.IDENTIFIER, + "#": TokenType.HASH, } - QUOTES = ["'"] + QUOTES: t.List[t.Tuple[str, str] | str] = ["'"] - BIT_STRINGS = [] + BIT_STRINGS: t.List[str | t.Tuple[str, str]] = [] - HEX_STRINGS = [] + HEX_STRINGS: t.List[str | t.Tuple[str, str]] = [] - BYTE_STRINGS = [] + BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = [] - IDENTIFIERS = ['"'] + IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"'] - ESCAPE = "'" + ESCAPES = ["'"] KEYWORDS = { "/*+": TokenType.HINT, @@ -406,8 +428,10 @@ class Tokenizer(metaclass=_Tokenizer): "<=": TokenType.LTE, "<>": TokenType.NEQ, "!=": TokenType.NEQ, + "<=>": TokenType.NULLSAFE_EQ, "->": TokenType.ARROW, "->>": TokenType.DARROW, + "=>": TokenType.FARROW, "#>": TokenType.HASH_ARROW, "#>>": TokenType.DHASH_ARROW, "<->": TokenType.LR_ARROW, @@ -454,6 +478,7 @@ class Tokenizer(metaclass=_Tokenizer): "DESCRIBE": TokenType.DESCRIBE, "DETERMINISTIC": TokenType.DETERMINISTIC, "DISTINCT": TokenType.DISTINCT, + "DISTINCT FROM": TokenType.DISTINCT_FROM, "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY, "DIV": TokenType.DIV, "DROP": TokenType.DROP, @@ -543,6 +568,7 @@ class Tokenizer(metaclass=_Tokenizer): "RETURNS": TokenType.RETURNS, "RIGHT": TokenType.RIGHT, "RLIKE": TokenType.RLIKE, + "ROLLBACK": TokenType.ROLLBACK, "ROLLUP": TokenType.ROLLUP, "ROW": TokenType.ROW, "ROWS": TokenType.ROWS, @@ -622,8 +648,9 @@ class Tokenizer(metaclass=_Tokenizer): "TEXT": TokenType.TEXT, "CLOB": TokenType.TEXT, "BINARY": TokenType.BINARY, - "BLOB": TokenType.BINARY, - "BYTEA": TokenType.BINARY, + "BLOB": TokenType.VARBINARY, + "BYTEA": TokenType.VARBINARY, + "VARBINARY": TokenType.VARBINARY, "TIMESTAMP": TokenType.TIMESTAMP, "TIMESTAMPTZ": TokenType.TIMESTAMPTZ, "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ, @@ -655,13 +682,13 @@ class Tokenizer(metaclass=_Tokenizer): TokenType.SET, TokenType.SHOW, TokenType.TRUNCATE, - TokenType.USE, TokenType.VACUUM, + TokenType.ROLLBACK, } # handle numeric literals like in hive (3L = BIGINT) - NUMERIC_LITERALS = {} - ENCODE = None + NUMERIC_LITERALS: t.Dict[str, str] = {} + ENCODE: t.Optional[str] = None COMMENTS = ["--", ("/*", "*/")] KEYWORD_TRIE = None # autofilled @@ -674,33 +701,39 @@ class Tokenizer(metaclass=_Tokenizer): "_current", "_line", "_col", + "_comment", "_char", "_end", "_peek", + "_prev_token_line", + "_prev_token_comment", "_prev_token_type", + "_replace_backslash", ) - def __init__(self): - """ - Tokenizer consumes a sql string and produces an array of :class:`~sqlglot.tokens.Token` - """ + def __init__(self) -> None: + self._replace_backslash = "\\" in self._ESCAPES # type: ignore self.reset() - def reset(self): + def reset(self) -> None: self.sql = "" self.size = 0 - self.tokens = [] + self.tokens: t.List[Token] = [] self._start = 0 self._current = 0 self._line = 1 self._col = 1 + self._comment = None self._char = None self._end = None self._peek = None + self._prev_token_line = -1 + self._prev_token_comment = None self._prev_token_type = None - def tokenize(self, sql): + def tokenize(self, sql: str) -> t.List[Token]: + """Returns a list of tokens corresponding to the SQL string `sql`.""" self.reset() self.sql = sql self.size = len(sql) @@ -712,14 +745,14 @@ class Tokenizer(metaclass=_Tokenizer): if not self._char: break - white_space = self.WHITE_SPACE.get(self._char) - identifier_end = self._IDENTIFIERS.get(self._char) + white_space = self.WHITE_SPACE.get(self._char) # type: ignore + identifier_end = self._IDENTIFIERS.get(self._char) # type: ignore if white_space: if white_space == TokenType.BREAK: self._col = 1 self._line += 1 - elif self._char.isdigit(): + elif self._char.isdigit(): # type:ignore self._scan_number() elif identifier_end: self._scan_identifier(identifier_end) @@ -727,38 +760,51 @@ class Tokenizer(metaclass=_Tokenizer): self._scan_keywords() return self.tokens - def _chars(self, size): + def _chars(self, size: int) -> str: if size == 1: - return self._char + return self._char # type: ignore start = self._current - 1 end = start + size if end <= self.size: return self.sql[start:end] return "" - def _advance(self, i=1): + def _advance(self, i: int = 1) -> None: self._col += i self._current += i - self._end = self._current >= self.size - self._char = self.sql[self._current - 1] - self._peek = self.sql[self._current] if self._current < self.size else "" + self._end = self._current >= self.size # type: ignore + self._char = self.sql[self._current - 1] # type: ignore + self._peek = self.sql[self._current] if self._current < self.size else "" # type: ignore @property - def _text(self): + def _text(self) -> str: return self.sql[self._start : self._current] - def _add(self, token_type, text=None): - self._prev_token_type = token_type - self.tokens.append(Token(token_type, self._text if text is None else text, self._line, self._col)) + def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None: + self._prev_token_line = self._line + self._prev_token_comment = self._comment + self._prev_token_type = token_type # type: ignore + self.tokens.append( + Token( + token_type, + self._text if text is None else text, + self._line, + self._col, + self._comment, + ) + ) + self._comment = None - if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON): + if token_type in self.COMMANDS and ( + len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON + ): self._start = self._current while not self._end and self._peek != ";": self._advance() if self._start < self._current: self._add(TokenType.STRING) - def _scan_keywords(self): + def _scan_keywords(self) -> None: size = 0 word = None chars = self._text @@ -771,7 +817,7 @@ class Tokenizer(metaclass=_Tokenizer): if skip: result = 1 else: - result, trie = in_trie(trie, char.upper()) + result, trie = in_trie(trie, char.upper()) # type: ignore if result == 0: break @@ -793,15 +839,11 @@ class Tokenizer(metaclass=_Tokenizer): else: skip = True else: - chars = None + chars = None # type: ignore if not word: if self._char in self.SINGLE_TOKENS: - token = self.SINGLE_TOKENS[self._char] - if token == TokenType.ANNOTATION: - self._scan_annotation() - return - self._add(token) + self._add(self.SINGLE_TOKENS[self._char]) # type: ignore return self._scan_var() return @@ -816,31 +858,41 @@ class Tokenizer(metaclass=_Tokenizer): self._advance(size - 1) self._add(self.KEYWORDS[word.upper()]) - def _scan_comment(self, comment_start): - if comment_start not in self._COMMENTS: + def _scan_comment(self, comment_start: str) -> bool: + if comment_start not in self._COMMENTS: # type: ignore return False - comment_end = self._COMMENTS[comment_start] + comment_start_line = self._line + comment_start_size = len(comment_start) + comment_end = self._COMMENTS[comment_start] # type: ignore if comment_end: comment_end_size = len(comment_end) while not self._end and self._chars(comment_end_size) != comment_end: self._advance() + + self._comment = self._text[comment_start_size : -comment_end_size + 1] # type: ignore self._advance(comment_end_size - 1) else: - while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: + while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore self._advance() - return True + self._comment = self._text[comment_start_size:] # type: ignore - def _scan_annotation(self): - while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK and self._peek != ",": - self._advance() - self._add(TokenType.ANNOTATION, self._text[1:]) + # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both + # types of comment can be attached to a token, the trailing one is discarded in favour of the leading one. - def _scan_number(self): + if comment_start_line == self._prev_token_line: + if self._prev_token_comment is None: + self.tokens[-1].comment = self._comment + + self._comment = None + + return True + + def _scan_number(self) -> None: if self._char == "0": - peek = self._peek.upper() + peek = self._peek.upper() # type: ignore if peek == "B": return self._scan_bits() elif peek == "X": @@ -850,7 +902,7 @@ class Tokenizer(metaclass=_Tokenizer): scientific = 0 while True: - if self._peek.isdigit(): + if self._peek.isdigit(): # type: ignore self._advance() elif self._peek == "." and not decimal: decimal = True @@ -858,25 +910,25 @@ class Tokenizer(metaclass=_Tokenizer): elif self._peek in ("-", "+") and scientific == 1: scientific += 1 self._advance() - elif self._peek.upper() == "E" and not scientific: + elif self._peek.upper() == "E" and not scientific: # type: ignore scientific += 1 self._advance() - elif self._peek.isalpha(): + elif self._peek.isalpha(): # type: ignore self._add(TokenType.NUMBER) literal = [] - while self._peek.isalpha(): - literal.append(self._peek.upper()) + while self._peek.isalpha(): # type: ignore + literal.append(self._peek.upper()) # type: ignore self._advance() - literal = "".join(literal) - token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) + literal = "".join(literal) # type: ignore + token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) # type: ignore if token_type: self._add(TokenType.DCOLON, "::") - return self._add(token_type, literal) + return self._add(token_type, literal) # type: ignore return self._advance(-len(literal)) else: return self._add(TokenType.NUMBER) - def _scan_bits(self): + def _scan_bits(self) -> None: self._advance() value = self._extract_value() try: @@ -884,7 +936,7 @@ class Tokenizer(metaclass=_Tokenizer): except ValueError: self._add(TokenType.IDENTIFIER) - def _scan_hex(self): + def _scan_hex(self) -> None: self._advance() value = self._extract_value() try: @@ -892,9 +944,9 @@ class Tokenizer(metaclass=_Tokenizer): except ValueError: self._add(TokenType.IDENTIFIER) - def _extract_value(self): + def _extract_value(self) -> str: while True: - char = self._peek.strip() + char = self._peek.strip() # type: ignore if char and char not in self.SINGLE_TOKENS: self._advance() else: @@ -902,31 +954,30 @@ class Tokenizer(metaclass=_Tokenizer): return self._text - def _scan_string(self, quote): - quote_end = self._QUOTES.get(quote) + def _scan_string(self, quote: str) -> bool: + quote_end = self._QUOTES.get(quote) # type: ignore if quote_end is None: return False self._advance(len(quote)) text = self._extract_string(quote_end) - - text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text - text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text + text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore + text = text.replace("\\\\", "\\") if self._replace_backslash else text self._add(TokenType.STRING, text) return True # X'1234, b'0110', E'\\\\\' etc. - def _scan_formatted_string(self, string_start): - if string_start in self._HEX_STRINGS: - delimiters = self._HEX_STRINGS + def _scan_formatted_string(self, string_start: str) -> bool: + if string_start in self._HEX_STRINGS: # type: ignore + delimiters = self._HEX_STRINGS # type: ignore token_type = TokenType.HEX_STRING base = 16 - elif string_start in self._BIT_STRINGS: - delimiters = self._BIT_STRINGS + elif string_start in self._BIT_STRINGS: # type: ignore + delimiters = self._BIT_STRINGS # type: ignore token_type = TokenType.BIT_STRING base = 2 - elif string_start in self._BYTE_STRINGS: - delimiters = self._BYTE_STRINGS + elif string_start in self._BYTE_STRINGS: # type: ignore + delimiters = self._BYTE_STRINGS # type: ignore token_type = TokenType.BYTE_STRING base = None else: @@ -942,11 +993,13 @@ class Tokenizer(metaclass=_Tokenizer): try: self._add(token_type, f"{int(text, base)}") except: - raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}") + raise RuntimeError( + f"Numeric string contains invalid characters from {self._line}:{self._start}" + ) return True - def _scan_identifier(self, identifier_end): + def _scan_identifier(self, identifier_end: str) -> None: while self._peek != identifier_end: if self._end: raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}") @@ -954,9 +1007,9 @@ class Tokenizer(metaclass=_Tokenizer): self._advance() self._add(TokenType.IDENTIFIER, self._text[1:-1]) - def _scan_var(self): + def _scan_var(self) -> None: while True: - char = self._peek.strip() + char = self._peek.strip() # type: ignore if char and char not in self.SINGLE_TOKENS: self._advance() else: @@ -967,12 +1020,12 @@ class Tokenizer(metaclass=_Tokenizer): else self.KEYWORDS.get(self._text.upper(), TokenType.VAR) ) - def _extract_string(self, delimiter): + def _extract_string(self, delimiter: str) -> str: text = "" delim_size = len(delimiter) while True: - if self._char == self.ESCAPE and self._peek == delimiter: + if self._char in self._ESCAPES and self._peek == delimiter: # type: ignore text += delimiter self._advance(2) else: @@ -983,7 +1036,7 @@ class Tokenizer(metaclass=_Tokenizer): if self._end: raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}") - text += self._char + text += self._char # type: ignore self._advance() return text diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 014ae00..412b881 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -1,7 +1,14 @@ +from __future__ import annotations + +import typing as t + +if t.TYPE_CHECKING: + from sqlglot.generator import Generator + from sqlglot import expressions as exp -def unalias_group(expression): +def unalias_group(expression: exp.Expression) -> exp.Expression: """ Replace references to select aliases in GROUP BY clauses. @@ -9,6 +16,12 @@ def unalias_group(expression): >>> import sqlglot >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 'SELECT a AS b FROM x GROUP BY 1' + + Args: + expression: the expression that will be transformed. + + Returns: + The transformed expression. """ if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): aliased_selects = { @@ -30,19 +43,20 @@ def unalias_group(expression): return expression -def preprocess(transforms, to_sql): +def preprocess( + transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], + to_sql: t.Callable[[Generator, exp.Expression], str], +) -> t.Callable[[Generator, exp.Expression], str]: """ - Create a new transform function that can be used a value in `Generator.TRANSFORMS` - to convert expressions to SQL. + Creates a new transform by chaining a sequence of transformations and converts the resulting + expression to SQL, using an appropriate `Generator.TRANSFORMS` function. Args: - transforms (list[(exp.Expression) -> exp.Expression]): - Sequence of transform functions. These will be called in order. - to_sql ((sqlglot.generator.Generator, exp.Expression) -> str): - Final transform that converts the resulting expression to a SQL string. + transforms: sequence of transform functions. These will be called in order. + to_sql: final transform that converts the resulting expression to a SQL string. + Returns: - (sqlglot.generator.Generator, exp.Expression) -> str: - Function that can be used as a generator transform. + Function that can be used as a generator transform. """ def _to_sql(self, expression): @@ -54,12 +68,10 @@ def preprocess(transforms, to_sql): return _to_sql -def delegate(attr): +def delegate(attr: str) -> t.Callable: """ - Create a new method that delegates to `attr`. - - This is useful for creating `Generator.TRANSFORMS` functions that delegate - to existing generator methods. + Create a new method that delegates to `attr`. This is useful for creating `Generator.TRANSFORMS` + functions that delegate to existing generator methods. """ def _transform(self, *args, **kwargs): diff --git a/sqlglot/trie.py b/sqlglot/trie.py index a234107..fa2aaf1 100644 --- a/sqlglot/trie.py +++ b/sqlglot/trie.py @@ -1,5 +1,26 @@ -def new_trie(keywords): - trie = {} +import typing as t + +key = t.Sequence[t.Hashable] + + +def new_trie(keywords: t.Iterable[key]) -> t.Dict: + """ + Creates a new trie out of a collection of keywords. + + The trie is represented as a sequence of nested dictionaries keyed by either single character + strings, or by 0, which is used to designate that a keyword is in the trie. + + Example: + >>> new_trie(["bla", "foo", "blab"]) + {'b': {'l': {'a': {0: True, 'b': {0: True}}}}, 'f': {'o': {'o': {0: True}}}} + + Args: + keywords: the keywords to create the trie from. + + Returns: + The trie corresponding to `keywords`. + """ + trie: t.Dict = {} for key in keywords: current = trie @@ -11,7 +32,28 @@ def new_trie(keywords): return trie -def in_trie(trie, key): +def in_trie(trie: t.Dict, key: key) -> t.Tuple[int, t.Dict]: + """ + Checks whether a key is in a trie. + + Examples: + >>> in_trie(new_trie(["cat"]), "bob") + (0, {'c': {'a': {'t': {0: True}}}}) + + >>> in_trie(new_trie(["cat"]), "ca") + (1, {'t': {0: True}}) + + >>> in_trie(new_trie(["cat"]), "cat") + (2, {0: True}) + + Args: + trie: the trie to be searched. + key: the target key. + + Returns: + A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the search stops, and `value` + is either 0 (search was unsuccessfull), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key is in `trie`). + """ if not key: return (0, trie) diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py index 4a89c78..16f8922 100644 --- a/tests/dataframe/integration/dataframe_validator.py +++ b/tests/dataframe/integration/dataframe_validator.py @@ -1,9 +1,9 @@ -import sys import typing as t import unittest import warnings import sqlglot +from sqlglot.helper import PYTHON_VERSION from tests.helpers import SKIP_INTEGRATION if t.TYPE_CHECKING: @@ -11,7 +11,8 @@ if t.TYPE_CHECKING: @unittest.skipIf( - SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set" + SKIP_INTEGRATION or PYTHON_VERSION > (3, 10), + "Skipping Integration Tests since `SKIP_INTEGRATION` is set", ) class DataFrameValidator(unittest.TestCase): spark = None @@ -36,7 +37,12 @@ class DataFrameValidator(unittest.TestCase): # This is for test `test_branching_root_dataframes` config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")]) - cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate() + cls.spark = ( + SparkSession.builder.master("local[*]") + .appName("Unit-tests") + .config(conf=config) + .getOrCreate() + ) cls.spark.sparkContext.setLogLevel("ERROR") cls.sqlglot = SqlglotSparkSession() cls.spark_employee_schema = types.StructType( @@ -50,7 +56,9 @@ class DataFrameValidator(unittest.TestCase): ) cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType( [ - sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField( + "employee_id", sqlglotSparkTypes.IntegerType(), False + ), sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False), sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False), sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False), @@ -64,8 +72,12 @@ class DataFrameValidator(unittest.TestCase): (4, "Claire", "Littleton", 27, 2), (5, "Hugo", "Reyes", 29, 100), ] - cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema) - cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema) + cls.df_employee = cls.spark.createDataFrame( + data=employee_data, schema=cls.spark_employee_schema + ) + cls.dfs_employee = cls.sqlglot.createDataFrame( + data=employee_data, schema=cls.sqlglot_employee_schema + ) cls.df_employee.createOrReplaceTempView("employee") cls.spark_store_schema = types.StructType( @@ -80,7 +92,9 @@ class DataFrameValidator(unittest.TestCase): [ sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False), sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False), - sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField( + "district_id", sqlglotSparkTypes.IntegerType(), False + ), sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False), ] ) @@ -89,7 +103,9 @@ class DataFrameValidator(unittest.TestCase): (2, "Arrow", 2, 2000), ] cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema) - cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema) + cls.dfs_store = cls.sqlglot.createDataFrame( + data=store_data, schema=cls.sqlglot_store_schema + ) cls.df_store.createOrReplaceTempView("store") cls.spark_district_schema = types.StructType( @@ -101,17 +117,27 @@ class DataFrameValidator(unittest.TestCase): ) cls.sqlglot_district_schema = sqlglotSparkTypes.StructType( [ - sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False), - sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False), - sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False), + sqlglotSparkTypes.StructField( + "district_id", sqlglotSparkTypes.IntegerType(), False + ), + sqlglotSparkTypes.StructField( + "district_name", sqlglotSparkTypes.StringType(), False + ), + sqlglotSparkTypes.StructField( + "manager_name", sqlglotSparkTypes.StringType(), False + ), ] ) district_data = [ (1, "Temple", "Dogen"), (2, "Lighthouse", "Jacob"), ] - cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema) - cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema) + cls.df_district = cls.spark.createDataFrame( + data=district_data, schema=cls.spark_district_schema + ) + cls.dfs_district = cls.sqlglot.createDataFrame( + data=district_data, schema=cls.sqlglot_district_schema + ) cls.df_district.createOrReplaceTempView("district") sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema) sqlglot.schema.add_table("store", cls.sqlglot_store_schema) diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py index c740bec..19e3b89 100644 --- a/tests/dataframe/integration/test_dataframe.py +++ b/tests/dataframe/integration/test_dataframe.py @@ -41,22 +41,32 @@ class TestDataframeFunc(DataFrameValidator): def test_alias_with_select(self): df_employee = self.df_spark_employee.alias("df_employee").select( - self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname + self.df_spark_employee["employee_id"], + F.col("df_employee.fname"), + self.df_spark_employee.lname, ) dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select( - self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname + self.df_sqlglot_employee["employee_id"], + SF.col("dfs_employee.fname"), + self.df_sqlglot_employee.lname, ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_case_when_otherwise(self): df = self.df_spark_employee.select( - F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")) + F.when( + (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), + F.lit("between 40 and 60"), + ) .when(F.col("age") < F.lit(40), "less than 40") .otherwise("greater than 60") ) dfs = self.df_sqlglot_employee.select( - SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")) + SF.when( + (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), + SF.lit("between 40 and 60"), + ) .when(SF.col("age") < SF.lit(40), "less than 40") .otherwise("greater than 60") ) @@ -65,15 +75,17 @@ class TestDataframeFunc(DataFrameValidator): def test_case_when_no_otherwise(self): df = self.df_spark_employee.select( - F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when( - F.col("age") < F.lit(40), "less than 40" - ) + F.when( + (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), + F.lit("between 40 and 60"), + ).when(F.col("age") < F.lit(40), "less than 40") ) dfs = self.df_sqlglot_employee.select( - SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when( - SF.col("age") < SF.lit(40), "less than 40" - ) + SF.when( + (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), + SF.lit("between 40 and 60"), + ).when(SF.col("age") < SF.lit(40), "less than 40") ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -84,7 +96,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_where_clause_multiple_and(self): - df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))) + df_employee = self.df_spark_employee.where( + (F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack")) + ) dfs_employee = self.df_sqlglot_employee.where( (SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack")) ) @@ -106,7 +120,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_where_clause_multiple_or(self): - df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))) + df_employee = self.df_spark_employee.where( + (F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate")) + ) dfs_employee = self.df_sqlglot_employee.where( (SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate")) ) @@ -172,28 +188,43 @@ class TestDataframeFunc(DataFrameValidator): dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37)) self.compare_spark_with_sqlglot(df_employee, dfs_employee) - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)) + df_employee = self.df_spark_employee.where( + self.df_spark_employee["age"] % F.lit(5) == F.lit(0) + ) + dfs_employee = self.df_sqlglot_employee.where( + self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0) + ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)) + df_employee = self.df_spark_employee.where( + self.df_spark_employee["age"] + F.lit(5) > F.lit(28) + ) + dfs_employee = self.df_sqlglot_employee.where( + self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28) + ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)) + df_employee = self.df_spark_employee.where( + self.df_spark_employee["age"] - F.lit(5) > F.lit(28) + ) + dfs_employee = self.df_sqlglot_employee.where( + self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28) + ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) df_employee = self.df_spark_employee.where( self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2) ) dfs_employee = self.df_sqlglot_employee.where( - self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2) + self.df_sqlglot_employee["age"] * SF.lit(0.5) + == self.df_sqlglot_employee["age"] / SF.lit(2) ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_join_inner(self): - df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select( + df_joined = self.df_spark_employee.join( + self.df_spark_store, on=["store_id"], how="inner" + ).select( self.df_spark_employee.employee_id, self.df_spark_employee["fname"], F.col("lname"), @@ -202,7 +233,9 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_store.store_name, self.df_spark_store["num_sales"], ) - dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select( + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, on=["store_id"], how="inner" + ).select( self.df_sqlglot_employee.employee_id, self.df_sqlglot_employee["fname"], SF.col("lname"), @@ -214,17 +247,27 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_joined, dfs_joined) def test_join_inner_no_select(self): - df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join( - self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner" + df_joined = self.df_spark_employee.select( + F.col("store_id"), F.col("fname"), F.col("lname") + ).join( + self.df_spark_store.select(F.col("store_id"), F.col("store_name")), + on=["store_id"], + how="inner", ) - dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner" + dfs_joined = self.df_sqlglot_employee.select( + SF.col("store_id"), SF.col("fname"), SF.col("lname") + ).join( + self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), + on=["store_id"], + how="inner", ) self.compare_spark_with_sqlglot(df_joined, dfs_joined) def test_join_inner_equality_single(self): df_joined = self.df_spark_employee.join( - self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner" + self.df_spark_store, + on=self.df_spark_employee.store_id == self.df_spark_store.store_id, + how="inner", ).select( self.df_spark_employee.employee_id, self.df_spark_employee["fname"], @@ -235,7 +278,9 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_store["num_sales"], ) dfs_joined = self.df_sqlglot_employee.join( - self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner" + self.df_sqlglot_store, + on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, + how="inner", ).select( self.df_sqlglot_employee.employee_id, self.df_sqlglot_employee["fname"], @@ -343,7 +388,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_joined, dfs_joined) def test_join_full_outer(self): - df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select( + df_joined = self.df_spark_employee.join( + self.df_spark_store, on=["store_id"], how="full_outer" + ).select( self.df_spark_employee.employee_id, self.df_spark_employee["fname"], F.col("lname"), @@ -352,7 +399,9 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_store.store_name, self.df_spark_store["num_sales"], ) - dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select( + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, on=["store_id"], how="full_outer" + ).select( self.df_sqlglot_employee.employee_id, self.df_sqlglot_employee["fname"], SF.col("lname"), @@ -365,7 +414,9 @@ class TestDataframeFunc(DataFrameValidator): def test_triple_join(self): df = ( - self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id) + self.df_employee.join( + self.df_store, on=self.df_employee.employee_id == self.df_store.store_id + ) .join(self.df_district, on=self.df_store.store_id == self.df_district.district_id) .select( self.df_employee.employee_id, @@ -377,7 +428,9 @@ class TestDataframeFunc(DataFrameValidator): ) ) dfs = ( - self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id) + self.dfs_employee.join( + self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id + ) .join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id) .select( self.dfs_employee.employee_id, @@ -391,13 +444,13 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_join_select_and_select_start(self): - df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join( - self.df_spark_store, "store_id", "inner" - ) + df = self.df_spark_employee.select( + F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id") + ).join(self.df_spark_store, "store_id", "inner") - dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join( - self.df_sqlglot_store, "store_id", "inner" - ) + dfs = self.df_sqlglot_employee.select( + SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id") + ).join(self.df_sqlglot_store, "store_id", "inner") self.compare_spark_with_sqlglot(df, dfs) @@ -485,13 +538,17 @@ class TestDataframeFunc(DataFrameValidator): dfs_unioned = ( self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname")) .unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))) - .unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))) + .unionAll( + self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")) + ) ) self.compare_spark_with_sqlglot(df_unioned, dfs_unioned) def test_union_by_name(self): - df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName( + df = self.df_spark_employee.select( + F.col("employee_id"), F.col("fname"), F.col("lname") + ).unionByName( self.df_spark_store.select( F.col("store_name").alias("lname"), F.col("store_id").alias("employee_id"), @@ -499,7 +556,9 @@ class TestDataframeFunc(DataFrameValidator): ) ) - dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName( + dfs = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("fname"), SF.col("lname") + ).unionByName( self.df_sqlglot_store.select( SF.col("store_name").alias("lname"), SF.col("store_id").alias("employee_id"), @@ -537,10 +596,16 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_order_by_default(self): - df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id")) + df = ( + self.df_spark_store.groupBy(F.col("district_id")) + .agg(F.min("num_sales")) + .orderBy(F.col("district_id")) + ) dfs = ( - self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id")) + self.df_sqlglot_store.groupBy(SF.col("district_id")) + .agg(SF.min("num_sales")) + .orderBy(SF.col("district_id")) ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -594,13 +659,17 @@ class TestDataframeFunc(DataFrameValidator): df = ( self.df_spark_store.groupBy(F.col("district_id")) .agg(F.min("num_sales").alias("total_sales")) - .orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()) + .orderBy( + F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last() + ) ) dfs = ( self.df_sqlglot_store.groupBy(SF.col("district_id")) .agg(SF.min("num_sales").alias("total_sales")) - .orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()) + .orderBy( + SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last() + ) ) self.compare_spark_with_sqlglot(df, dfs) @@ -609,81 +678,87 @@ class TestDataframeFunc(DataFrameValidator): df = ( self.df_spark_store.groupBy(F.col("district_id")) .agg(F.min("num_sales").alias("total_sales")) - .orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()) + .orderBy( + F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first() + ) ) dfs = ( self.df_sqlglot_store.groupBy(SF.col("district_id")) .agg(SF.min("num_sales").alias("total_sales")) - .orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first()) + .orderBy( + SF.when( + SF.col("district_id") == SF.lit(1), SF.col("district_id") + ).desc_nulls_first() + ) ) self.compare_spark_with_sqlglot(df, dfs) def test_intersect(self): - df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( - self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) - ) + df_employee_duplicate = self.df_spark_employee.select( + F.col("employee_id"), F.col("store_id") + ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( - self.df_spark_store.select(F.col("store_id"), F.col("district_id")) - ) + df_store_duplicate = self.df_spark_store.select( + F.col("store_id"), F.col("district_id") + ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) df = df_employee_duplicate.intersect(df_store_duplicate) - dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( - self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) - ) + dfs_employee_duplicate = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("store_id") + ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) - ) + dfs_store_duplicate = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("district_id") + ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate) self.compare_spark_with_sqlglot(df, dfs) def test_intersect_all(self): - df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( - self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) - ) + df_employee_duplicate = self.df_spark_employee.select( + F.col("employee_id"), F.col("store_id") + ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( - self.df_spark_store.select(F.col("store_id"), F.col("district_id")) - ) + df_store_duplicate = self.df_spark_store.select( + F.col("store_id"), F.col("district_id") + ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) df = df_employee_duplicate.intersectAll(df_store_duplicate) - dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( - self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) - ) + dfs_employee_duplicate = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("store_id") + ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) - ) + dfs_store_duplicate = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("district_id") + ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate) self.compare_spark_with_sqlglot(df, dfs) def test_except_all(self): - df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( - self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) - ) + df_employee_duplicate = self.df_spark_employee.select( + F.col("employee_id"), F.col("store_id") + ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( - self.df_spark_store.select(F.col("store_id"), F.col("district_id")) - ) + df_store_duplicate = self.df_spark_store.select( + F.col("store_id"), F.col("district_id") + ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) df = df_employee_duplicate.exceptAll(df_store_duplicate) - dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( - self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) - ) + dfs_employee_duplicate = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("store_id") + ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) - ) + dfs_store_duplicate = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("district_id") + ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate) @@ -721,7 +796,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_drop_na_default(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).dropna() + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).dropna() dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -746,7 +823,9 @@ class TestDataframeFunc(DataFrameValidator): ).dropna(how="any", thresh=2) dfs = self.df_sqlglot_employee.select( - SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + SF.lit(None), + SF.lit(1), + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"), ).dropna(how="any", thresh=2) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -757,13 +836,17 @@ class TestDataframeFunc(DataFrameValidator): ).dropna(thresh=1, subset="the_age") dfs = self.df_sqlglot_employee.select( - SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + SF.lit(None), + SF.lit(1), + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"), ).dropna(thresh=1, subset="the_age") self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) def test_dropna_na_function(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.drop() + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).na.drop() dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -772,7 +855,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_fillna_default(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).fillna(100) + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).fillna(100) dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -798,7 +883,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) def test_fillna_na_func(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.fill(100) + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).na.fill(100) dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -807,7 +894,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_replace_basic(self): - df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(to_replace=37, value=100) + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace( + to_replace=37, value=100 + ) dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace( to_replace=37, value=100 @@ -827,9 +916,13 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) def test_replace_mapping(self): - df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace({37: 100}) + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace( + {37: 100} + ) - dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace({37: 100}) + dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace( + {37: 100} + ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -849,9 +942,9 @@ class TestDataframeFunc(DataFrameValidator): to_replace=37, value=100 ) - dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).na.replace( - to_replace=37, value=100 - ) + dfs = self.df_sqlglot_employee.select( + SF.col("age"), SF.lit(37).alias("test_col") + ).na.replace(to_replace=37, value=100) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -881,16 +974,18 @@ class TestDataframeFunc(DataFrameValidator): "first_name", "first_name_again" ) - dfs = self.df_sqlglot_employee.select(SF.col("fname").alias("first_name")).withColumnRenamed( - "first_name", "first_name_again" - ) + dfs = self.df_sqlglot_employee.select( + SF.col("fname").alias("first_name") + ).withColumnRenamed("first_name", "first_name_again") self.compare_spark_with_sqlglot(df, dfs) def test_drop_column_single(self): df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age") - dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop("age") + dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop( + "age" + ) self.compare_spark_with_sqlglot(df, dfs) @@ -906,7 +1001,9 @@ class TestDataframeFunc(DataFrameValidator): df_sqlglot_employee_cols = self.df_sqlglot_employee.select( SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id") ) - df_sqlglot_store_cols = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")) + df_sqlglot_store_cols = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("store_name") + ) dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop( df_sqlglot_employee_cols.age, ) diff --git a/tests/dataframe/integration/test_session.py b/tests/dataframe/integration/test_session.py index ff1477b..ec50034 100644 --- a/tests/dataframe/integration/test_session.py +++ b/tests/dataframe/integration/test_session.py @@ -23,6 +23,14 @@ class TestSessionFunc(DataFrameValidator): ON e.store_id = s.store_id """ - df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id"))) - dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id"))) + df = ( + self.spark.sql(query) + .groupBy(F.col("store_id")) + .agg(F.countDistinct(F.col("employee_id"))) + ) + dfs = ( + self.sqlglot.sql(query) + .groupBy(SF.col("store_id")) + .agg(SF.countDistinct(SF.col("employee_id"))) + ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py index fc56553..32ff8f2 100644 --- a/tests/dataframe/unit/dataframe_sql_validator.py +++ b/tests/dataframe/unit/dataframe_sql_validator.py @@ -25,11 +25,17 @@ class DataFrameSQLValidator(unittest.TestCase): (4, "Claire", "Littleton", 27, 2), (5, "Hugo", "Reyes", 29, 100), ] - self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema) + self.df_employee = self.spark.createDataFrame( + data=employee_data, schema=self.employee_schema + ) - def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False): + def compare_sql( + self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False + ): actual_sqls = df.sql(pretty=pretty) - expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements + expected_statements = ( + [expected_statements] if isinstance(expected_statements, str) else expected_statements + ) self.assertEqual(len(expected_statements), len(actual_sqls)) for expected, actual in zip(expected_statements, actual_sqls): self.assertEqual(expected, actual) diff --git a/tests/dataframe/unit/test_column.py b/tests/dataframe/unit/test_column.py index 977971e..da18502 100644 --- a/tests/dataframe/unit/test_column.py +++ b/tests/dataframe/unit/test_column.py @@ -26,12 +26,14 @@ class TestDataframeColumn(unittest.TestCase): def test_and(self): self.assertEqual( - "cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql() + "cola = colb AND colc = cold", + ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql(), ) def test_or(self): self.assertEqual( - "cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql() + "cola = colb OR colc = cold", + ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql(), ) def test_mod(self): @@ -112,7 +114,9 @@ class TestDataframeColumn(unittest.TestCase): def test_when_otherwise(self): self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql()) - self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()) + self.assertEqual( + "CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql() + ) self.assertEqual( "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END", (F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(), @@ -148,7 +152,9 @@ class TestDataframeColumn(unittest.TestCase): self.assertEqual( "cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) " "AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)", - F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(), + F.col("cola") + .between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)) + .sql(), ) def test_over(self): diff --git a/tests/dataframe/unit/test_dataframe.py b/tests/dataframe/unit/test_dataframe.py index c222cac..e36667b 100644 --- a/tests/dataframe/unit/test_dataframe.py +++ b/tests/dataframe/unit/test_dataframe.py @@ -9,7 +9,9 @@ class TestDataframe(DataFrameSQLValidator): self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression)) def test_columns(self): - self.assertEqual(["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns) + self.assertEqual( + ["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns + ) def test_cache(self): df = self.df_employee.select("fname").cache() diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index eadbb93..8e5e5cd 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -925,12 +925,17 @@ class TestFunctions(unittest.TestCase): col = SF.window(SF.col("cola"), "10 minutes") self.assertEqual("WINDOW(cola, '10 minutes')", col.sql()) col_all_values = SF.window("cola", "2 minutes 30 seconds", "30 seconds", "15 seconds") - self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql()) + self.assertEqual( + "WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql() + ) col_no_start_time = SF.window("cola", "2 minutes 30 seconds", "30 seconds") - self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql()) + self.assertEqual( + "WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql() + ) col_no_slide = SF.window("cola", "2 minutes 30 seconds", startTime="15 seconds") self.assertEqual( - "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')", col_no_slide.sql() + "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')", + col_no_slide.sql(), ) def test_session_window(self): @@ -1359,9 +1364,13 @@ class TestFunctions(unittest.TestCase): def test_from_json(self): col_str = SF.from_json("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy")) - self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + self.assertEqual( + "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql() + ) col = SF.from_json(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy")) - self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + self.assertEqual( + "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql() + ) col_no_option = SF.from_json("cola", "cola INT") self.assertEqual("FROM_JSON(cola, 'cola INT')", col_no_option.sql()) @@ -1375,7 +1384,9 @@ class TestFunctions(unittest.TestCase): def test_schema_of_json(self): col_str = SF.schema_of_json("cola", dict(timestampFormat="dd/MM/yyyy")) - self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + self.assertEqual( + "SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql() + ) col = SF.schema_of_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy")) self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) col_no_option = SF.schema_of_json("cola") @@ -1429,7 +1440,10 @@ class TestFunctions(unittest.TestCase): col = SF.array_sort(SF.col("cola")) self.assertEqual("ARRAY_SORT(cola)", col.sql()) col_comparator = SF.array_sort( - "cola", lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(SF.length(y) - SF.length(x)) + "cola", + lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise( + SF.length(y) - SF.length(x) + ), ) self.assertEqual( "ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)", @@ -1504,9 +1518,13 @@ class TestFunctions(unittest.TestCase): def test_from_csv(self): col_str = SF.from_csv("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy")) - self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + self.assertEqual( + "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql() + ) col = SF.from_csv(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy")) - self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + self.assertEqual( + "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql() + ) col_no_option = SF.from_csv("cola", "cola INT") self.assertEqual("FROM_CSV(cola, 'cola INT')", col_no_option.sql()) @@ -1535,7 +1553,9 @@ class TestFunctions(unittest.TestCase): self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql()) col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count) - self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()) + self.assertEqual( + "TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql() + ) def test_exists(self): col_str = SF.exists("cola", lambda x: x % 2 == 0) @@ -1558,10 +1578,13 @@ class TestFunctions(unittest.TestCase): self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql()) col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i)) self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql()) - col_custom_names = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count)) + col_custom_names = SF.filter( + "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count) + ) self.assertEqual( - "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql() + "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", + col_custom_names.sql(), ) def test_zip_with(self): @@ -1570,7 +1593,9 @@ class TestFunctions(unittest.TestCase): col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y)) self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql()) col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r)) - self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()) + self.assertEqual( + "ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql() + ) def test_transform_keys(self): col_str = SF.transform_keys("cola", lambda k, v: SF.upper(k)) @@ -1586,7 +1611,9 @@ class TestFunctions(unittest.TestCase): col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v)) self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql()) col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value)) - self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()) + self.assertEqual( + "TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql() + ) def test_map_filter(self): col_str = SF.map_filter("cola", lambda k, v: k > v) diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py index 158dcec..7e8bfad 100644 --- a/tests/dataframe/unit/test_session.py +++ b/tests/dataframe/unit/test_session.py @@ -21,9 +21,7 @@ class TestDataframeSession(DataFrameSQLValidator): def test_cdf_no_schema(self): df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]]) - expected = ( - "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)" - ) + expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)" self.compare_sql(df, expected) def test_cdf_row_mixed_primitives(self): @@ -77,7 +75,8 @@ class TestDataframeSession(DataFrameSQLValidator): sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) df = self.spark.sql(query) self.assertIn( - "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", df.sql(pretty=False) + "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", + df.sql(pretty=False), ) @mock.patch("sqlglot.schema", MappingSchema()) @@ -104,9 +103,7 @@ class TestDataframeSession(DataFrameSQLValidator): query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1" sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) df = self.spark.sql(query) - expected = ( - "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" - ) + expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" self.compare_sql(df, expected) def test_session_create_builder_patterns(self): diff --git a/tests/dataframe/unit/test_types.py b/tests/dataframe/unit/test_types.py index 1f6c5dc..52f5d72 100644 --- a/tests/dataframe/unit/test_types.py +++ b/tests/dataframe/unit/test_types.py @@ -53,7 +53,10 @@ class TestDataframeTypes(unittest.TestCase): self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString()) def test_map(self): - self.assertEqual("map<int, string>", types.MapType(types.IntegerType(), types.StringType()).simpleString()) + self.assertEqual( + "map<int, string>", + types.MapType(types.IntegerType(), types.StringType()).simpleString(), + ) def test_struct_field(self): self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString()) diff --git a/tests/dataframe/unit/test_window.py b/tests/dataframe/unit/test_window.py index eea4582..70a868a 100644 --- a/tests/dataframe/unit/test_window.py +++ b/tests/dataframe/unit/test_window.py @@ -39,22 +39,38 @@ class TestDataframeWindow(unittest.TestCase): def test_window_rows_unbounded(self): rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2) - self.assertEqual("OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql()) + self.assertEqual( + "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", + rows_between_unbounded_start.sql(), + ) rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing) - self.assertEqual("OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql()) - rows_between_unbounded_both = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) self.assertEqual( - "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql() + "OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", + rows_between_unbounded_end.sql(), + ) + rows_between_unbounded_both = Window.rowsBetween( + Window.unboundedPreceding, Window.unboundedFollowing + ) + self.assertEqual( + "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", + rows_between_unbounded_both.sql(), ) def test_window_range_unbounded(self): range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2) self.assertEqual( - "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql() + "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", + range_between_unbounded_start.sql(), ) range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing) - self.assertEqual("OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql()) - range_between_unbounded_both = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing) self.assertEqual( - "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql() + "OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", + range_between_unbounded_end.sql(), + ) + range_between_unbounded_both = Window.rangeBetween( + Window.unboundedPreceding, Window.unboundedFollowing + ) + self.assertEqual( + "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", + range_between_unbounded_both.sql(), ) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 050d41e..a0ebc45 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -157,6 +157,14 @@ class TestBigQuery(Validator): }, ) + self.validate_all( + "DIV(x, y)", + write={ + "bigquery": "DIV(x, y)", + "duckdb": "CAST(x / y AS INT)", + }, + ) + self.validate_identity( "SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)" ) @@ -284,4 +292,6 @@ class TestBigQuery(Validator): "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'" ) self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)") - self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t") + self.validate_identity( + "CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t" + ) diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 715bf10..efb41bb 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -18,7 +18,6 @@ class TestClickhouse(Validator): "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", }, ) - self.validate_all( "CAST(1 AS NULLABLE(Int64))", write={ @@ -31,3 +30,7 @@ class TestClickhouse(Validator): "clickhouse": "CAST(1 AS Nullable(DateTime64(6, 'UTC')))", }, ) + self.validate_all( + "SELECT x #! comment", + write={"": "SELECT x /* comment */"}, + ) diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index e242e73..2168f55 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -22,7 +22,8 @@ class TestDatabricks(Validator): }, ) self.validate_all( - "SELECT DATEDIFF('end', 'start')", write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"} + "SELECT DATEDIFF('end', 'start')", + write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"}, ) self.validate_all( "SELECT DATE_ADD('2020-01-01', 1)", diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 3b837df..1913f53 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1,20 +1,18 @@ import unittest -from sqlglot import ( - Dialect, - Dialects, - ErrorLevel, - UnsupportedError, - parse_one, - transpile, -) +from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one class Validator(unittest.TestCase): dialect = None - def validate_identity(self, sql): - self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql) + def parse_one(self, sql): + return parse_one(sql, read=self.dialect) + + def validate_identity(self, sql, write_sql=None): + expression = self.parse_one(sql) + self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect)) + return expression def validate_all(self, sql, read=None, write=None, pretty=False): """ @@ -28,12 +26,14 @@ class Validator(unittest.TestCase): read (dict): Mapping of dialect -> SQL write (dict): Mapping of dialect -> SQL """ - expression = parse_one(sql, read=self.dialect) + expression = self.parse_one(sql) for read_dialect, read_sql in (read or {}).items(): with self.subTest(f"{read_dialect} -> {sql}"): self.assertEqual( - parse_one(read_sql, read_dialect).sql(self.dialect, unsupported_level=ErrorLevel.IGNORE), + parse_one(read_sql, read_dialect).sql( + self.dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty + ), sql, ) @@ -83,10 +83,6 @@ class TestDialect(Validator): ) self.validate_all( "CAST(a AS BINARY(4))", - read={ - "presto": "CAST(a AS VARBINARY(4))", - "sqlite": "CAST(a AS VARBINARY(4))", - }, write={ "bigquery": "CAST(a AS BINARY(4))", "clickhouse": "CAST(a AS BINARY(4))", @@ -104,6 +100,24 @@ class TestDialect(Validator): }, ) self.validate_all( + "CAST(a AS VARBINARY(4))", + write={ + "bigquery": "CAST(a AS VARBINARY(4))", + "clickhouse": "CAST(a AS VARBINARY(4))", + "duckdb": "CAST(a AS VARBINARY(4))", + "mysql": "CAST(a AS VARBINARY(4))", + "hive": "CAST(a AS BINARY(4))", + "oracle": "CAST(a AS BLOB(4))", + "postgres": "CAST(a AS BYTEA(4))", + "presto": "CAST(a AS VARBINARY(4))", + "redshift": "CAST(a AS VARBYTE(4))", + "snowflake": "CAST(a AS VARBINARY(4))", + "sqlite": "CAST(a AS BLOB(4))", + "spark": "CAST(a AS BINARY(4))", + "starrocks": "CAST(a AS VARBINARY(4))", + }, + ) + self.validate_all( "CAST(MAP('a', '1') AS MAP(TEXT, TEXT))", write={ "clickhouse": "CAST(map('a', '1') AS Map(TEXT, TEXT))", @@ -472,45 +486,57 @@ class TestDialect(Validator): }, ) self.validate_all( - "DATE_TRUNC(x, 'day')", + "DATE_TRUNC('day', x)", write={ "mysql": "DATE(x)", - "starrocks": "DATE(x)", }, ) self.validate_all( - "DATE_TRUNC(x, 'week')", + "DATE_TRUNC('week', x)", write={ "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')", - "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')", }, ) self.validate_all( - "DATE_TRUNC(x, 'month')", + "DATE_TRUNC('month', x)", write={ "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')", - "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')", }, ) self.validate_all( - "DATE_TRUNC(x, 'quarter')", + "DATE_TRUNC('quarter', x)", write={ "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')", - "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')", }, ) self.validate_all( - "DATE_TRUNC(x, 'year')", + "DATE_TRUNC('year', x)", write={ "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')", - "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')", }, ) self.validate_all( - "DATE_TRUNC(x, 'millenium')", + "DATE_TRUNC('millenium', x)", write={ "mysql": UnsupportedError, - "starrocks": UnsupportedError, + }, + ) + self.validate_all( + "DATE_TRUNC('year', x)", + read={ + "starrocks": "DATE_TRUNC('year', x)", + }, + write={ + "starrocks": "DATE_TRUNC('year', x)", + }, + ) + self.validate_all( + "DATE_TRUNC(x, year)", + read={ + "bigquery": "DATE_TRUNC(x, year)", + }, + write={ + "bigquery": "DATE_TRUNC(x, year)", }, ) self.validate_all( @@ -564,6 +590,22 @@ class TestDialect(Validator): "spark": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)", }, ) + self.validate_all( + "TIMESTAMP '2022-01-01'", + write={ + "mysql": "CAST('2022-01-01' AS TIMESTAMP)", + "starrocks": "CAST('2022-01-01' AS DATETIME)", + "hive": "CAST('2022-01-01' AS TIMESTAMP)", + }, + ) + self.validate_all( + "TIMESTAMP('2022-01-01')", + write={ + "mysql": "TIMESTAMP('2022-01-01')", + "starrocks": "TIMESTAMP('2022-01-01')", + "hive": "TIMESTAMP('2022-01-01')", + }, + ) for unit in ("DAY", "MONTH", "YEAR"): self.validate_all( @@ -1002,7 +1044,10 @@ class TestDialect(Validator): ) def test_limit(self): - self.validate_all("SELECT * FROM data LIMIT 10, 20", write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"}) + self.validate_all( + "SELECT * FROM data LIMIT 10, 20", + write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"}, + ) self.validate_all( "SELECT x FROM y LIMIT 10", write={ @@ -1132,3 +1177,56 @@ class TestDialect(Validator): "sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c", }, ) + + def test_nullsafe_eq(self): + self.validate_all( + "SELECT a IS NOT DISTINCT FROM b", + read={ + "mysql": "SELECT a <=> b", + "postgres": "SELECT a IS NOT DISTINCT FROM b", + }, + write={ + "mysql": "SELECT a <=> b", + "postgres": "SELECT a IS NOT DISTINCT FROM b", + }, + ) + + def test_nullsafe_neq(self): + self.validate_all( + "SELECT a IS DISTINCT FROM b", + read={ + "postgres": "SELECT a IS DISTINCT FROM b", + }, + write={ + "mysql": "SELECT NOT a <=> b", + "postgres": "SELECT a IS DISTINCT FROM b", + }, + ) + + def test_hash_comments(self): + self.validate_all( + "SELECT 1 /* arbitrary content,,, until end-of-line */", + read={ + "mysql": "SELECT 1 # arbitrary content,,, until end-of-line", + "bigquery": "SELECT 1 # arbitrary content,,, until end-of-line", + "clickhouse": "SELECT 1 #! arbitrary content,,, until end-of-line", + }, + ) + self.validate_all( + """/* comment1 */ +SELECT + x, -- comment2 + y -- comment3""", + read={ + "mysql": """SELECT # comment1 + x, # comment2 + y # comment3""", + "bigquery": """SELECT # comment1 + x, # comment2 + y # comment3""", + "clickhouse": """SELECT # comment1 + x, # comment2 + y # comment3""", + }, + pretty=True, + ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index a25871c..1ba118b 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -1,3 +1,4 @@ +from sqlglot import expressions as exp from tests.dialects.test_dialect import Validator @@ -20,6 +21,52 @@ class TestMySQL(Validator): self.validate_identity("SELECT TRIM(TRAILING 'bla' FROM ' XXX ')") self.validate_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')") self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')") + self.validate_identity("@@GLOBAL.max_connections") + + # SET Commands + self.validate_identity("SET @var_name = expr") + self.validate_identity("SET @name = 43") + self.validate_identity("SET @total_tax = (SELECT SUM(tax) FROM taxable_transactions)") + self.validate_identity("SET GLOBAL max_connections = 1000") + self.validate_identity("SET @@GLOBAL.max_connections = 1000") + self.validate_identity("SET SESSION sql_mode = 'TRADITIONAL'") + self.validate_identity("SET LOCAL sql_mode = 'TRADITIONAL'") + self.validate_identity("SET @@SESSION.sql_mode = 'TRADITIONAL'") + self.validate_identity("SET @@LOCAL.sql_mode = 'TRADITIONAL'") + self.validate_identity("SET @@sql_mode = 'TRADITIONAL'") + self.validate_identity("SET sql_mode = 'TRADITIONAL'") + self.validate_identity("SET PERSIST max_connections = 1000") + self.validate_identity("SET @@PERSIST.max_connections = 1000") + self.validate_identity("SET PERSIST_ONLY back_log = 100") + self.validate_identity("SET @@PERSIST_ONLY.back_log = 100") + self.validate_identity("SET @@SESSION.max_join_size = DEFAULT") + self.validate_identity("SET @@SESSION.max_join_size = @@GLOBAL.max_join_size") + self.validate_identity("SET @x = 1, SESSION sql_mode = ''") + self.validate_identity( + "SET GLOBAL sort_buffer_size = 1000000, SESSION sort_buffer_size = 1000000" + ) + self.validate_identity( + "SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000" + ) + self.validate_identity("SET GLOBAL max_connections = 1000, sort_buffer_size = 1000000") + self.validate_identity("SET @@GLOBAL.sort_buffer_size = 50000, sort_buffer_size = 1000000") + self.validate_identity("SET CHARACTER SET 'utf8'") + self.validate_identity("SET CHARACTER SET utf8") + self.validate_identity("SET CHARACTER SET DEFAULT") + self.validate_identity("SET NAMES 'utf8'") + self.validate_identity("SET NAMES DEFAULT") + self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'") + self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci") + self.validate_identity("SET autocommit = ON") + + def test_escape(self): + self.validate_all( + r"'a \' b '' '", + write={ + "mysql": r"'a '' b '' '", + "spark": r"'a \' b \' '", + }, + ) def test_introducers(self): self.validate_all( @@ -115,14 +162,6 @@ class TestMySQL(Validator): }, ) - def test_hash_comments(self): - self.validate_all( - "SELECT 1 # arbitrary content,,, until end-of-line", - write={ - "mysql": "SELECT 1", - }, - ) - def test_mysql(self): self.validate_all( "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)", @@ -174,3 +213,242 @@ COMMENT='客户账户表'""" }, pretty=True, ) + + def test_show_simple(self): + for key, write_key in [ + ("BINARY LOGS", "BINARY LOGS"), + ("MASTER LOGS", "BINARY LOGS"), + ("STORAGE ENGINES", "ENGINES"), + ("ENGINES", "ENGINES"), + ("EVENTS", "EVENTS"), + ("MASTER STATUS", "MASTER STATUS"), + ("PLUGINS", "PLUGINS"), + ("PRIVILEGES", "PRIVILEGES"), + ("PROFILES", "PROFILES"), + ("REPLICAS", "REPLICAS"), + ("SLAVE HOSTS", "REPLICAS"), + ]: + show = self.validate_identity(f"SHOW {key}", f"SHOW {write_key}") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, write_key) + + def test_show_events(self): + for key in ["BINLOG", "RELAYLOG"]: + show = self.validate_identity(f"SHOW {key} EVENTS") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, f"{key} EVENTS") + + show = self.validate_identity(f"SHOW {key} EVENTS IN 'log' FROM 1 LIMIT 2, 3") + self.assertEqual(show.text("log"), "log") + self.assertEqual(show.text("position"), "1") + self.assertEqual(show.text("limit"), "3") + self.assertEqual(show.text("offset"), "2") + + show = self.validate_identity(f"SHOW {key} EVENTS LIMIT 1") + self.assertEqual(show.text("limit"), "1") + self.assertIsNone(show.args.get("offset")) + + def test_show_like_or_where(self): + for key, write_key in [ + ("CHARSET", "CHARACTER SET"), + ("CHARACTER SET", "CHARACTER SET"), + ("COLLATION", "COLLATION"), + ("DATABASES", "DATABASES"), + ("FUNCTION STATUS", "FUNCTION STATUS"), + ("PROCEDURE STATUS", "PROCEDURE STATUS"), + ("GLOBAL STATUS", "GLOBAL STATUS"), + ("SESSION STATUS", "STATUS"), + ("STATUS", "STATUS"), + ("GLOBAL VARIABLES", "GLOBAL VARIABLES"), + ("SESSION VARIABLES", "VARIABLES"), + ("VARIABLES", "VARIABLES"), + ]: + expected_name = write_key.strip("GLOBAL").strip() + template = "SHOW {}" + show = self.validate_identity(template.format(key), template.format(write_key)) + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, expected_name) + + template = "SHOW {} LIKE '%foo%'" + show = self.validate_identity(template.format(key), template.format(write_key)) + self.assertIsInstance(show, exp.Show) + self.assertIsInstance(show.args["like"], exp.Literal) + self.assertEqual(show.text("like"), "%foo%") + + template = "SHOW {} WHERE Column_name LIKE '%foo%'" + show = self.validate_identity(template.format(key), template.format(write_key)) + self.assertIsInstance(show, exp.Show) + self.assertIsInstance(show.args["where"], exp.Where) + self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'") + + def test_show_columns(self): + show = self.validate_identity("SHOW COLUMNS FROM tbl_name") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "COLUMNS") + self.assertEqual(show.text("target"), "tbl_name") + self.assertFalse(show.args["full"]) + + show = self.validate_identity("SHOW FULL COLUMNS FROM tbl_name FROM db_name LIKE '%foo%'") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.text("target"), "tbl_name") + self.assertTrue(show.args["full"]) + self.assertEqual(show.text("db"), "db_name") + self.assertIsInstance(show.args["like"], exp.Literal) + self.assertEqual(show.text("like"), "%foo%") + + def test_show_name(self): + for key in [ + "CREATE DATABASE", + "CREATE EVENT", + "CREATE FUNCTION", + "CREATE PROCEDURE", + "CREATE TABLE", + "CREATE TRIGGER", + "CREATE VIEW", + "FUNCTION CODE", + "PROCEDURE CODE", + ]: + show = self.validate_identity(f"SHOW {key} foo") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, key) + self.assertEqual(show.text("target"), "foo") + + def test_show_grants(self): + show = self.validate_identity(f"SHOW GRANTS FOR foo") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "GRANTS") + self.assertEqual(show.text("target"), "foo") + + def test_show_engine(self): + show = self.validate_identity("SHOW ENGINE foo STATUS") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "ENGINE") + self.assertEqual(show.text("target"), "foo") + self.assertFalse(show.args["mutex"]) + + show = self.validate_identity("SHOW ENGINE foo MUTEX") + self.assertEqual(show.name, "ENGINE") + self.assertEqual(show.text("target"), "foo") + self.assertTrue(show.args["mutex"]) + + def test_show_errors(self): + for key in ["ERRORS", "WARNINGS"]: + show = self.validate_identity(f"SHOW {key}") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, key) + + show = self.validate_identity(f"SHOW {key} LIMIT 2, 3") + self.assertEqual(show.text("limit"), "3") + self.assertEqual(show.text("offset"), "2") + + def test_show_index(self): + show = self.validate_identity("SHOW INDEX FROM foo") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "INDEX") + self.assertEqual(show.text("target"), "foo") + + show = self.validate_identity("SHOW INDEX FROM foo FROM bar") + self.assertEqual(show.text("db"), "bar") + + def test_show_db_like_or_where_sql(self): + for key in [ + "OPEN TABLES", + "TABLE STATUS", + "TRIGGERS", + ]: + show = self.validate_identity(f"SHOW {key}") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, key) + + show = self.validate_identity(f"SHOW {key} FROM db_name") + self.assertEqual(show.name, key) + self.assertEqual(show.text("db"), "db_name") + + show = self.validate_identity(f"SHOW {key} LIKE '%foo%'") + self.assertEqual(show.name, key) + self.assertIsInstance(show.args["like"], exp.Literal) + self.assertEqual(show.text("like"), "%foo%") + + show = self.validate_identity(f"SHOW {key} WHERE Column_name LIKE '%foo%'") + self.assertEqual(show.name, key) + self.assertIsInstance(show.args["where"], exp.Where) + self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'") + + def test_show_processlist(self): + show = self.validate_identity("SHOW PROCESSLIST") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "PROCESSLIST") + self.assertFalse(show.args["full"]) + + show = self.validate_identity("SHOW FULL PROCESSLIST") + self.assertEqual(show.name, "PROCESSLIST") + self.assertTrue(show.args["full"]) + + def test_show_profile(self): + show = self.validate_identity("SHOW PROFILE") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "PROFILE") + + show = self.validate_identity("SHOW PROFILE BLOCK IO") + self.assertEqual(show.args["types"][0].name, "BLOCK IO") + + show = self.validate_identity( + "SHOW PROFILE BLOCK IO, PAGE FAULTS FOR QUERY 1 OFFSET 2 LIMIT 3" + ) + self.assertEqual(show.args["types"][0].name, "BLOCK IO") + self.assertEqual(show.args["types"][1].name, "PAGE FAULTS") + self.assertEqual(show.text("query"), "1") + self.assertEqual(show.text("offset"), "2") + self.assertEqual(show.text("limit"), "3") + + def test_show_replica_status(self): + show = self.validate_identity("SHOW REPLICA STATUS") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "REPLICA STATUS") + + show = self.validate_identity("SHOW SLAVE STATUS", "SHOW REPLICA STATUS") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "REPLICA STATUS") + + show = self.validate_identity("SHOW REPLICA STATUS FOR CHANNEL channel_name") + self.assertEqual(show.text("channel"), "channel_name") + + def test_show_tables(self): + show = self.validate_identity("SHOW TABLES") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "TABLES") + + show = self.validate_identity("SHOW FULL TABLES FROM db_name LIKE '%foo%'") + self.assertTrue(show.args["full"]) + self.assertEqual(show.text("db"), "db_name") + self.assertIsInstance(show.args["like"], exp.Literal) + self.assertEqual(show.text("like"), "%foo%") + + def test_set_variable(self): + cmd = self.parse_one("SET SESSION x = 1") + item = cmd.expressions[0] + self.assertEqual(item.text("kind"), "SESSION") + self.assertIsInstance(item.this, exp.EQ) + self.assertEqual(item.this.left.name, "x") + self.assertEqual(item.this.right.name, "1") + + cmd = self.parse_one("SET @@GLOBAL.x = @@GLOBAL.y") + item = cmd.expressions[0] + self.assertEqual(item.text("kind"), "") + self.assertIsInstance(item.this, exp.EQ) + self.assertIsInstance(item.this.left, exp.SessionParameter) + self.assertIsInstance(item.this.right, exp.SessionParameter) + + cmd = self.parse_one("SET NAMES 'charset_name' COLLATE 'collation_name'") + item = cmd.expressions[0] + self.assertEqual(item.text("kind"), "NAMES") + self.assertEqual(item.name, "charset_name") + self.assertEqual(item.text("collate"), "collation_name") + + cmd = self.parse_one("SET CHARSET DEFAULT") + item = cmd.expressions[0] + self.assertEqual(item.text("kind"), "CHARACTER SET") + self.assertEqual(item.this.name, "DEFAULT") + + cmd = self.parse_one("SET x = 1, y = 2") + self.assertEqual(len(cmd.expressions), 2) diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 35141e2..8294eea 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -8,7 +8,9 @@ class TestPostgres(Validator): def test_ddl(self): self.validate_all( "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)", - write={"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"}, + write={ + "postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)" + }, ) self.validate_all( "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)", @@ -59,15 +61,27 @@ class TestPostgres(Validator): def test_postgres(self): self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END") - self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END") - self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END") - self.validate_identity('SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')') + self.validate_identity( + "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END" + ) + self.validate_identity( + "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END" + ) + self.validate_identity( + 'SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')' + ) self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')") - self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')") - self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))") + self.validate_identity( + "SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')" + ) + self.validate_identity( + "SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))" + ) self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')") self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)") - self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')") + self.validate_identity( + "SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')" + ) self.validate_identity("COMMENT ON TABLE mytable IS 'this'") self.validate_identity("SELECT e'\\xDEADBEEF'") self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)") @@ -75,7 +89,7 @@ class TestPostgres(Validator): self.validate_all( "CREATE TABLE x (a UUID, b BYTEA)", write={ - "duckdb": "CREATE TABLE x (a UUID, b BINARY)", + "duckdb": "CREATE TABLE x (a UUID, b VARBINARY)", "presto": "CREATE TABLE x (a UUID, b VARBINARY)", "hive": "CREATE TABLE x (a UUID, b BINARY)", "spark": "CREATE TABLE x (a UUID, b BINARY)", @@ -153,7 +167,9 @@ class TestPostgres(Validator): ) self.validate_all( "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss", - read={"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"}, + read={ + "postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss" + }, ) self.validate_all( "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL", @@ -169,11 +185,15 @@ class TestPostgres(Validator): ) self.validate_all( "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL", - read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL"}, + read={ + "postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL" + }, ) self.validate_all( "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL", - read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL"}, + read={ + "postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL" + }, ) self.validate_all( "'[1,2,3]'::json->2", @@ -184,7 +204,8 @@ class TestPostgres(Validator): write={"postgres": """CAST('{"a":1,"b":2}' AS JSON)->'b'"""}, ) self.validate_all( - """'{"x": {"y": 1}}'::json->'x'->'y'""", write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""} + """'{"x": {"y": 1}}'::json->'x'->'y'""", + write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""}, ) self.validate_all( """'{"x": {"y": 1}}'::json->'x'::json->'y'""", diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 1ed2bb6..5309a34 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -61,4 +61,6 @@ class TestRedshift(Validator): "SELECT caldate + INTERVAL '1 second' AS dateplus FROM date WHERE caldate = '12-31-2008'" ) self.validate_identity("CREATE TABLE datetable (start_date DATE, end_date DATE)") - self.validate_identity("SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'") + self.validate_identity( + "SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'" + ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index fea2311..1846b17 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -336,7 +336,8 @@ class TestSnowflake(Validator): def test_table_literal(self): # All examples from https://docs.snowflake.com/en/sql-reference/literals-table.html self.validate_all( - r"""SELECT * FROM TABLE('MYTABLE')""", write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""} + r"""SELECT * FROM TABLE('MYTABLE')""", + write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""}, ) self.validate_all( @@ -352,15 +353,123 @@ class TestSnowflake(Validator): write={"snowflake": r"""SELECT * FROM TABLE('MYDB. "MYSCHEMA"."MYTABLE"')"""}, ) - self.validate_all(r"""SELECT * FROM TABLE($MYVAR)""", write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""}) + self.validate_all( + r"""SELECT * FROM TABLE($MYVAR)""", + write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""}, + ) - self.validate_all(r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""}) + self.validate_all( + r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""} + ) self.validate_all( - r"""SELECT * FROM TABLE(:BINDING)""", write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""} + r"""SELECT * FROM TABLE(:BINDING)""", + write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""}, ) self.validate_all( r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10""", write={"snowflake": r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10"""}, ) + + def test_flatten(self): + self.validate_all( + """ + select + dag_report.acct_id, + dag_report.report_date, + dag_report.report_uuid, + dag_report.airflow_name, + dag_report.dag_id, + f.value::varchar as operator + from cs.telescope.dag_report, + table(flatten(input=>split(operators, ','))) f + """, + write={ + "snowflake": """SELECT + dag_report.acct_id, + dag_report.report_date, + dag_report.report_uuid, + dag_report.airflow_name, + dag_report.dag_id, + CAST(f.value AS VARCHAR) AS operator +FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS f""" + }, + pretty=True, + ) + + # All examples from https://docs.snowflake.com/en/sql-reference/functions/flatten.html#syntax + self.validate_all( + "SELECT * FROM TABLE(FLATTEN(input => parse_json('[1, ,77]'))) f", + write={ + "snowflake": "SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[1, ,77]'))) AS f" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), outer => true)) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), outer => TRUE)) AS f""" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), path => 'b')) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), path => 'b')) AS f""" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'))) f""", + write={"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'))) AS f"""}, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'), outer => true)) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'), outer => TRUE)) AS f""" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) AS f""" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true)) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE)) AS f""" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true, mode => 'object')) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE, mode => 'object')) AS f""" + }, + ) + + self.validate_all( + """ + SELECT id as "ID", + f.value AS "Contact", + f1.value:type AS "Type", + f1.value:content AS "Details" + FROM persons p, + lateral flatten(input => p.c, path => 'contact') f, + lateral flatten(input => f.value:business) f1 + """, + write={ + "snowflake": """SELECT + id AS "ID", + f.value AS "Contact", + f1.value['type'] AS "Type", + f1.value['content'] AS "Details" +FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL FLATTEN(input => f.value['business']) f1""", + }, + pretty=True, + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 8605bd1..4470722 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -284,4 +284,6 @@ TBLPROPERTIES ( ) def test_iif(self): - self.validate_all("SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"}) + self.validate_all( + "SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"} + ) diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py index 1fe1a57..35d8b45 100644 --- a/tests/dialects/test_starrocks.py +++ b/tests/dialects/test_starrocks.py @@ -6,3 +6,6 @@ class TestMySQL(Validator): def test_identity(self): self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo") + + def test_time(self): + self.validate_identity("TIMESTAMP('2022-01-01')") diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index d22a9c2..a60f48d 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -278,12 +278,19 @@ class TestTSQL(Validator): def test_add_date(self): self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')") self.validate_all( - "SELECT DATEADD(year, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"} + "SELECT DATEADD(year, 1, '2017/08/25')", + write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"}, + ) + self.validate_all( + "SELECT DATEADD(qq, 1, '2017/08/25')", + write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"}, ) - self.validate_all("SELECT DATEADD(qq, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"}) self.validate_all( "SELECT DATEADD(wk, 1, '2017/08/25')", - write={"spark": "SELECT DATE_ADD('2017/08/25', 7)", "databricks": "SELECT DATEADD(week, 1, '2017/08/25')"}, + write={ + "spark": "SELECT DATE_ADD('2017/08/25', 7)", + "databricks": "SELECT DATEADD(week, 1, '2017/08/25')", + }, ) def test_date_diff(self): @@ -370,13 +377,21 @@ class TestTSQL(Validator): "SELECT FORMAT(1000000.01,'###,###.###')", write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"}, ) - self.validate_all("SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"}) + self.validate_all( + "SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"} + ) self.validate_all( "SELECT FORMAT('01-01-1991', 'dd.mm.yyyy')", write={"spark": "SELECT DATE_FORMAT('01-01-1991', 'dd.mm.yyyy')"}, ) self.validate_all( - "SELECT FORMAT(date_col, 'dd.mm.yyyy')", write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"} + "SELECT FORMAT(date_col, 'dd.mm.yyyy')", + write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"}, + ) + self.validate_all( + "SELECT FORMAT(date_col, 'm')", + write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"}, + ) + self.validate_all( + "SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"} ) - self.validate_all("SELECT FORMAT(date_col, 'm')", write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"}) - self.validate_all("SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"}) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index d7084ac..836ab28 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -523,6 +523,8 @@ DROP VIEW a.b DROP VIEW IF EXISTS a DROP VIEW IF EXISTS a.b SHOW TABLES +USE db +ROLLBACK EXPLAIN SELECT * FROM x INSERT INTO x SELECT * FROM y INSERT INTO x (SELECT * FROM y) @@ -569,3 +571,13 @@ SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1) SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3) SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo) SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl) +SELECT CAST(x AS INT) /* comment */ FROM foo +SELECT a /* x */, b /* x */ +SELECT * FROM foo /* x */, bla /* x */ +SELECT 1 /* comment */ + 1 +SELECT 1 /* c1 */ + 2 /* c2 */ +SELECT 1 /* c1 */ + 2 /* c2 */ + 3 /* c3 */ +SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */ +SELECT x FROM a.b.c /* x */, e.f.g /* x */ +SELECT FOO(x /* c */) /* FOO */, b /* b */ +SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */ diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index a958c08..1176078 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -104,6 +104,16 @@ SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_ SELECT AGGREGATE(ARRAY(a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x; SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + x.a) AS sum_agg FROM x AS x; +# dialect: starrocks +# execute: false +SELECT DATE_TRUNC('week', a) AS a FROM x; +SELECT DATE_TRUNC('week', x.a) AS a FROM x AS x; + +# dialect: bigquery +# execute: false +SELECT DATE_TRUNC(a, MONTH) AS a FROM x; +SELECT DATE_TRUNC(x.a, MONTH) AS a FROM x AS x; + -------------------------------------- -- Derived tables -------------------------------------- diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 07e818f..7207ba2 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -79,6 +79,15 @@ NULL; NULL = NULL; NULL; +NULL <=> NULL; +TRUE; + +a IS NOT DISTINCT FROM a; +TRUE; + +NULL IS DISTINCT FROM NULL; +FALSE; + NOT (NOT TRUE); TRUE; diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql index 2570650..5e27b5e 100644 --- a/tests/fixtures/pretty.sql +++ b/tests/fixtures/pretty.sql @@ -287,3 +287,31 @@ SELECT "fffffff" ) ); +/* + multi + line + comment +*/ +SELECT * FROM foo; +/* + multi + line + comment +*/ +SELECT + * +FROM foo; +SELECT x FROM a.b.c /*x*/, e.f.g /*x*/; +SELECT + x +FROM a.b.c /* x */, e.f.g /* x */; +SELECT x FROM (SELECT * FROM bla /*x*/WHERE id = 1) /*x*/; +SELECT + x +FROM ( + SELECT + * + FROM bla /* x */ + WHERE + id = 1 +) /* x */; diff --git a/tests/test_build.py b/tests/test_build.py index b7b6865..721c868 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -100,15 +100,21 @@ class TestBuild(unittest.TestCase): "SELECT x FROM tbl LEFT OUTER JOIN tbl2", ), ( - lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer"), + lambda: select("x") + .from_("tbl") + .join(exp.Table(this="tbl2"), join_type="left outer"), "SELECT x FROM tbl LEFT OUTER JOIN tbl2", ), ( - lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"), + lambda: select("x") + .from_("tbl") + .join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"), "SELECT x FROM tbl LEFT OUTER JOIN tbl2 AS foo", ), ( - lambda: select("x").from_("tbl").join(select("y").from_("tbl2"), join_type="left outer"), + lambda: select("x") + .from_("tbl") + .join(select("y").from_("tbl2"), join_type="left outer"), "SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2)", ), ( @@ -131,7 +137,9 @@ class TestBuild(unittest.TestCase): "SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2) AS aliased", ), ( - lambda: select("x").from_("tbl").join(parse_one("left join x", into=exp.Join), on="a=b"), + lambda: select("x") + .from_("tbl") + .join(parse_one("left join x", into=exp.Join), on="a=b"), "SELECT x FROM tbl LEFT JOIN x ON a = b", ), ( @@ -139,7 +147,9 @@ class TestBuild(unittest.TestCase): "SELECT x FROM tbl LEFT JOIN x ON a = b", ), ( - lambda: select("x").from_("tbl").join("select b from tbl2", on="a=b", join_type="left"), + lambda: select("x") + .from_("tbl") + .join("select b from tbl2", on="a=b", join_type="left"), "SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) ON a = b", ), ( @@ -162,7 +172,10 @@ class TestBuild(unittest.TestCase): ( lambda: select("x", "y", "z") .from_("merged_df") - .join("vte_diagnosis_df", using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")]), + .join( + "vte_diagnosis_df", + using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")], + ), "SELECT x, y, z FROM merged_df JOIN vte_diagnosis_df USING (patient_id, encounter_id)", ), ( @@ -222,7 +235,10 @@ class TestBuild(unittest.TestCase): "SELECT x, y, z, a FROM tbl ORDER BY x, y, z, a", ), ( - lambda: select("x", "y", "z", "a").from_("tbl").cluster_by("x, y", "z").cluster_by("a"), + lambda: select("x", "y", "z", "a") + .from_("tbl") + .cluster_by("x, y", "z") + .cluster_by("a"), "SELECT x, y, z, a FROM tbl CLUSTER BY x, y, z, a", ), ( @@ -239,7 +255,9 @@ class TestBuild(unittest.TestCase): "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", ), ( - lambda: select("x").from_("tbl").with_("tbl", as_="SELECT x FROM tbl2", recursive=True), + lambda: select("x") + .from_("tbl") + .with_("tbl", as_="SELECT x FROM tbl2", recursive=True), "WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", ), ( @@ -247,7 +265,9 @@ class TestBuild(unittest.TestCase): "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", ), ( - lambda: select("x").from_("tbl").with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")), + lambda: select("x") + .from_("tbl") + .with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")), "WITH tbl(x, y) AS (SELECT x, y FROM tbl2) SELECT x FROM tbl", ), ( @@ -258,7 +278,10 @@ class TestBuild(unittest.TestCase): "WITH tbl AS (SELECT x FROM tbl2), tbl2 AS (SELECT x FROM tbl3) SELECT x FROM tbl", ), ( - lambda: select("x").from_("tbl").with_("tbl", as_=select("x", "y").from_("tbl2")).select("y"), + lambda: select("x") + .from_("tbl") + .with_("tbl", as_=select("x", "y").from_("tbl2")) + .select("y"), "WITH tbl AS (SELECT x, y FROM tbl2) SELECT x, y FROM tbl", ), ( @@ -266,35 +289,59 @@ class TestBuild(unittest.TestCase): "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", ), ( - lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").group_by("x"), + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .group_by("x"), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl GROUP BY x", ), ( - lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").order_by("x"), + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .order_by("x"), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl ORDER BY x", ), ( - lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").limit(10), + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .limit(10), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl LIMIT 10", ), ( - lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").offset(10), + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .offset(10), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl OFFSET 10", ), ( - lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").join("tbl3"), + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .join("tbl3"), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl JOIN tbl3", ), ( - lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").distinct(), + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .distinct(), "WITH tbl AS (SELECT x FROM tbl2) SELECT DISTINCT x FROM tbl", ), ( - lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").where("x > 10"), + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .where("x > 10"), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl WHERE x > 10", ), ( - lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").having("x > 20"), + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .having("x > 20"), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl HAVING x > 20", ), (lambda: select("x").from_("tbl").subquery(), "(SELECT x FROM tbl)"), @@ -354,7 +401,9 @@ class TestBuild(unittest.TestCase): "SELECT x FROM (SELECT x FROM tbl) AS foo WHERE x > 0", ), ( - lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select("x"), + lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select( + "x" + ), "SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned", ), ( diff --git a/tests/test_executor.py b/tests/test_executor.py index ef1a706..49805b9 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -33,7 +33,10 @@ class TestExecutor(unittest.TestCase): ) cls.cache = {} - cls.sqls = [(sql, expected) for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")] + cls.sqls = [ + (sql, expected) + for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql") + ] @classmethod def tearDownClass(cls): @@ -63,7 +66,9 @@ class TestExecutor(unittest.TestCase): def test_execute_tpch(self): def to_csv(expression): if isinstance(expression, exp.Table): - return parse_one(f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}") + return parse_one( + f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}" + ) return expression for sql, _ in self.sqls[0:3]: diff --git a/tests/test_expressions.py b/tests/test_expressions.py index adfd329..63371d8 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -30,7 +30,9 @@ class TestExpressions(unittest.TestCase): self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)")) self.assertEqual(exp.Table(pivots=[]), exp.Table()) self.assertNotEqual(exp.Table(pivots=[None]), exp.Table()) - self.assertEqual(exp.DataType.build("int"), exp.DataType(this=exp.DataType.Type.INT, nested=False)) + self.assertEqual( + exp.DataType.build("int"), exp.DataType(this=exp.DataType.Type.INT, nested=False) + ) def test_find(self): expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y") @@ -89,7 +91,9 @@ class TestExpressions(unittest.TestCase): self.assertIsNone(column.find_ancestor(exp.Join)) def test_alias_or_name(self): - expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz") + expression = parse_one( + "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz" + ) self.assertEqual( [e.alias_or_name for e in expression.expressions], ["a", "B", "e", "*", "zz", "z"], @@ -166,7 +170,9 @@ class TestExpressions(unittest.TestCase): "SELECT * FROM foo WHERE ? > 100", ) self.assertEqual( - exp.replace_placeholders(parse_one("select * from :name WHERE ? > 100"), another_name="bla").sql(), + exp.replace_placeholders( + parse_one("select * from :name WHERE ? > 100"), another_name="bla" + ).sql(), "SELECT * FROM :name WHERE ? > 100", ) self.assertEqual( @@ -183,7 +189,9 @@ class TestExpressions(unittest.TestCase): ) def test_named_selects(self): - expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz") + expression = parse_one( + "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz" + ) self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"]) expression = parse_one( @@ -367,7 +375,9 @@ class TestExpressions(unittest.TestCase): self.assertEqual(len(list(expression.walk())), 9) self.assertEqual(len(list(expression.walk(bfs=False))), 9) self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk())) - self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False))) + self.assertTrue( + all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False)) + ) def test_functions(self): self.assertIsInstance(parse_one("ABS(a)"), exp.Abs) @@ -512,14 +522,21 @@ class TestExpressions(unittest.TestCase): ), exp.Properties( expressions=[ - exp.FileFormatProperty(this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet")), + exp.FileFormatProperty( + this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet") + ), exp.PartitionedByProperty( this=exp.Literal.string("PARTITIONED_BY"), - value=exp.Tuple(expressions=[exp.to_identifier("a"), exp.to_identifier("b")]), + value=exp.Tuple( + expressions=[exp.to_identifier("a"), exp.to_identifier("b")] + ), + ), + exp.AnonymousProperty( + this=exp.Literal.string("custom"), value=exp.Literal.number(1) ), - exp.AnonymousProperty(this=exp.Literal.string("custom"), value=exp.Literal.number(1)), exp.TableFormatProperty( - this=exp.Literal.string("TABLE_FORMAT"), value=exp.to_identifier("test_format") + this=exp.Literal.string("TABLE_FORMAT"), + value=exp.to_identifier("test_format"), ), exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.NULL), exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.TRUE), @@ -538,7 +555,10 @@ class TestExpressions(unittest.TestCase): ((1, "2", None), "(1, '2', NULL)"), ([1, "2", None], "ARRAY(1, '2', NULL)"), ({"x": None}, "MAP('x', NULL)"), - (datetime.datetime(2022, 10, 1, 1, 1, 1), "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')"), + ( + datetime.datetime(2022, 10, 1, 1, 1, 1), + "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')", + ), ( datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc), "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000+0000')", @@ -548,30 +568,48 @@ class TestExpressions(unittest.TestCase): with self.subTest(value): self.assertEqual(exp.convert(value).sql(), expected) - def test_annotation_alias(self): - sql = "SELECT a, b AS B, c # comment, d AS D # another_comment FROM foo" + def test_comment_alias(self): + sql = """ + SELECT + a, + b AS B, + c, /*comment*/ + d AS D, -- another comment + CAST(x AS INT) -- final comment + FROM foo + """ expression = parse_one(sql) self.assertEqual( [e.alias_or_name for e in expression.expressions], - ["a", "B", "c", "D"], + ["a", "B", "c", "D", "x"], + ) + self.assertEqual( + expression.sql(), + "SELECT a, b AS B, c /* comment */, d AS D /* another comment */, CAST(x AS INT) /* final comment */ FROM foo", + ) + self.assertEqual( + expression.sql(comments=False), + "SELECT a, b AS B, c, d AS D, CAST(x AS INT) FROM foo", ) - self.assertEqual(expression.sql(), "SELECT a, b AS B, c, d AS D") - self.assertEqual(expression.expressions[2].name, "comment") self.assertEqual( - expression.sql(pretty=True, annotations=False), + expression.sql(pretty=True, comments=False), """SELECT a, b AS B, c, - d AS D""", + d AS D, + CAST(x AS INT) +FROM foo""", ) self.assertEqual( expression.sql(pretty=True), """SELECT a, b AS B, - c # comment, - d AS D # another_comment FROM foo""", + c, -- comment + d AS D, -- another comment + CAST(x AS INT) -- final comment +FROM foo""", ) def test_to_table(self): @@ -605,5 +643,9 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(expression, exp.Union) self.assertEqual(expression.named_selects, ["cola", "colb"]) self.assertEqual( - expression.selects, [exp.Column(this=exp.to_identifier("cola")), exp.Column(this=exp.to_identifier("colb"))] + expression.selects, + [ + exp.Column(this=exp.to_identifier("cola")), + exp.Column(this=exp.to_identifier("colb")), + ], ) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 3b5990f..a1b7e70 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -67,7 +67,9 @@ class TestOptimizer(unittest.TestCase): } def check_file(self, file, func, pretty=False, execute=False, **kwargs): - for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1): + for i, (meta, sql, expected) in enumerate( + load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1 + ): title = meta.get("title") or f"{i}, {sql}" dialect = meta.get("dialect") leave_tables_isolated = meta.get("leave_tables_isolated") @@ -90,7 +92,9 @@ class TestOptimizer(unittest.TestCase): if string_to_bool(should_execute): with self.subTest(f"(execute) {title}"): - df1 = self.conn.execute(sqlglot.transpile(sql, read=dialect, write="duckdb")[0]).df() + df1 = self.conn.execute( + sqlglot.transpile(sql, read=dialect, write="duckdb")[0] + ).df() df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df() assert_frame_equal(df1, df2) @@ -268,7 +272,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") self.assertEqual(scopes[2].expression.sql(), "(VALUES (1, 'test')) AS tab(cola, colb)") self.assertEqual( - scopes[3].expression.sql(), "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)" + scopes[3].expression.sql(), + "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", ) self.assertEqual(scopes[4].expression.sql(), "SELECT y.c AS b FROM y") self.assertEqual(scopes[5].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") @@ -287,7 +292,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') # Check that we can walk in scope from an arbitrary node self.assertEqual( - {node.sql() for node, *_ in walk_in_scope(expression.find(exp.Where)) if isinstance(node, exp.Column)}, + { + node.sql() + for node, *_ in walk_in_scope(expression.find(exp.Where)) + if isinstance(node, exp.Column) + }, {"s.b"}, ) @@ -324,7 +333,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT) def test_cache_annotation(self): - expression = annotate_types(parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")) + expression = annotate_types( + parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1") + ) self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT) def test_binary_annotation(self): @@ -384,7 +395,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') """ expression = annotate_types(parse_one(sql), schema=schema) - self.assertEqual(expression.expressions[0].type, exp.DataType.Type.TEXT) # tbl.cola + tbl.colb + 'foo' AS col + self.assertEqual( + expression.expressions[0].type, exp.DataType.Type.TEXT + ) # tbl.cola + tbl.colb + 'foo' AS col outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo' self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT) @@ -396,7 +409,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT) cte_select = expression.args["with"].expressions[0].this - self.assertEqual(cte_select.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola + 'bla' AS cola + self.assertEqual( + cte_select.expressions[0].type, exp.DataType.Type.VARCHAR + ) # x.cola + 'bla' AS cola self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla' @@ -405,7 +420,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR) # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively - for d, t in zip(cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]): + for d, t in zip( + cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT] + ): self.assertEqual(d.this.expressions[0].this.type, t) def test_function_annotation(self): @@ -421,6 +438,19 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb) self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb + sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x" + + case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] + self.assertEqual(case_expr_alias.type, exp.DataType.Type.VARCHAR) + + case_expr = case_expr_alias.this + self.assertEqual(case_expr.type, exp.DataType.Type.VARCHAR) + self.assertEqual(case_expr.args["default"].type, exp.DataType.Type.CHAR) + + case_ifs_expr = case_expr.args["ifs"][0] + self.assertEqual(case_ifs_expr.type, exp.DataType.Type.VARCHAR) + self.assertEqual(case_ifs_expr.args["true"].type, exp.DataType.Type.VARCHAR) + def test_unknown_annotation(self): schema = {"x": {"cola": "VARCHAR"}} sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x" @@ -431,8 +461,12 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') concat_expr = concat_expr_alias.this self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN) self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola - self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN) # SOME_ANONYMOUS_FUNC(x.cola) - self.assertEqual(concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola (arg) + self.assertEqual( + concat_expr.right.type, exp.DataType.Type.UNKNOWN + ) # SOME_ANONYMOUS_FUNC(x.cola) + self.assertEqual( + concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR + ) # x.cola (arg) def test_null_annotation(self): expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this diff --git a/tests/test_parser.py b/tests/test_parser.py index 9afeae6..04c20b1 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -23,8 +23,6 @@ class TestParser(unittest.TestCase): def test_float(self): self.assertEqual(parse_one(".2"), parse_one("0.2")) - self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)")) - self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)")) def test_table(self): tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)] @@ -33,7 +31,9 @@ class TestParser(unittest.TestCase): def test_select(self): self.assertIsNotNone(parse_one("select 1 natural")) self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"]) - self.assertIsNotNone(parse_one("select * from x where a = (select 1) order by x.y").args["order"]) + self.assertIsNotNone( + parse_one("select * from x where a = (select 1) order by x.y").args["order"] + ) self.assertEqual(len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1) self.assertEqual( parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(), @@ -125,26 +125,70 @@ class TestParser(unittest.TestCase): def test_var(self): self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'") - def test_annotations(self): + def test_comments(self): expression = parse_one( """ - SELECT - a #annotation1, - b as B #annotation2:testing , - "test#annotation",c#annotation3, d #annotation4, - e #, - f # space + --comment1 + SELECT /* this won't be used */ + a, --comment2 + b as B, --comment3:testing + "test--annotation", + c, --comment4 --foo + e, -- + f -- space FROM foo """ ) - assert expression.expressions[0].name == "annotation1" - assert expression.expressions[1].name == "annotation2:testing" - assert expression.expressions[2].name == "test#annotation" - assert expression.expressions[3].name == "annotation3" - assert expression.expressions[4].name == "annotation4" - assert expression.expressions[5].name == "" - assert expression.expressions[6].name == "space" + self.assertEqual(expression.comment, "comment1") + self.assertEqual(expression.expressions[0].comment, "comment2") + self.assertEqual(expression.expressions[1].comment, "comment3:testing") + self.assertEqual(expression.expressions[2].comment, None) + self.assertEqual(expression.expressions[3].comment, "comment4 --foo") + self.assertEqual(expression.expressions[4].comment, "") + self.assertEqual(expression.expressions[5].comment, " space") + + def test_type_literals(self): + self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)")) + self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)")) + self.assertEqual( + parse_one("TIMESTAMP '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP)" + ) + self.assertEqual( + parse_one("TIMESTAMP(1) '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP(1))" + ) + self.assertEqual( + parse_one("TIMESTAMP WITH TIME ZONE '2022-01-01'").sql(), + "CAST('2022-01-01' AS TIMESTAMPTZ)", + ) + self.assertEqual( + parse_one("TIMESTAMP WITH LOCAL TIME ZONE '2022-01-01'").sql(), + "CAST('2022-01-01' AS TIMESTAMPLTZ)", + ) + self.assertEqual( + parse_one("TIMESTAMP WITHOUT TIME ZONE '2022-01-01'").sql(), + "CAST('2022-01-01' AS TIMESTAMP)", + ) + self.assertEqual( + parse_one("TIMESTAMP(1) WITH TIME ZONE '2022-01-01'").sql(), + "CAST('2022-01-01' AS TIMESTAMPTZ(1))", + ) + self.assertEqual( + parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE '2022-01-01'").sql(), + "CAST('2022-01-01' AS TIMESTAMPLTZ(1))", + ) + self.assertEqual( + parse_one("TIMESTAMP(1) WITHOUT TIME ZONE '2022-01-01'").sql(), + "CAST('2022-01-01' AS TIMESTAMP(1))", + ) + self.assertEqual(parse_one("TIMESTAMP(1) WITH TIME ZONE").sql(), "TIMESTAMPTZ(1)") + self.assertEqual(parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE").sql(), "TIMESTAMPLTZ(1)") + self.assertEqual(parse_one("TIMESTAMP(1) WITHOUT TIME ZONE").sql(), "TIMESTAMP(1)") + self.assertEqual(parse_one("""JSON '{"x":"y"}'""").sql(), """CAST('{"x":"y"}' AS JSON)""") + self.assertIsInstance(parse_one("TIMESTAMP(1)"), exp.Func) + self.assertIsInstance(parse_one("TIMESTAMP('2022-01-01')"), exp.Func) + self.assertIsInstance(parse_one("TIMESTAMP()"), exp.Func) + self.assertIsInstance(parse_one("map.x"), exp.Column) def test_pretty_config_override(self): self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x") diff --git a/tests/test_schema.py b/tests/test_schema.py index bab97d8..cc0e3d1 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,281 +1,141 @@ import unittest -from sqlglot import table -from sqlglot.dataframe.sql import types as df_types +from sqlglot import exp, to_table +from sqlglot.errors import SchemaError from sqlglot.schema import MappingSchema, ensure_schema class TestSchema(unittest.TestCase): - def test_schema(self): - schema = ensure_schema( - { - "x": { - "a": "uint64", - } - } - ) - self.assertEqual( - schema.column_names( - table( - "x", - ) - ), - ["a"], - ) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db")) - with self.assertRaises(ValueError): - schema.column_names(table("x2")) + def assert_column_names(self, schema, *table_results): + for table, result in table_results: + with self.subTest(f"{table} -> {result}"): + self.assertEqual(schema.column_names(to_table(table)), result) - with self.assertRaises(ValueError): - schema.add_table(table("y", db="db"), {"b": "string"}) - with self.assertRaises(ValueError): - schema.add_table(table("y", db="db", catalog="c"), {"b": "string"}) + def assert_column_names_raises(self, schema, *tables): + for table in tables: + with self.subTest(table): + with self.assertRaises(SchemaError): + schema.column_names(to_table(table)) - schema.add_table(table("y"), {"b": "string"}) - schema_with_y = { - "x": { - "a": "uint64", - }, - "y": { - "b": "string", - }, - } - self.assertEqual(schema.schema, schema_with_y) - - new_schema = schema.copy() - new_schema.add_table(table("z"), {"c": "string"}) - self.assertEqual(schema.schema, schema_with_y) - self.assertEqual( - new_schema.schema, - { - "x": { - "a": "uint64", - }, - "y": { - "b": "string", - }, - "z": { - "c": "string", - }, - }, - ) - schema.add_table(table("m"), {"d": "string"}) - schema.add_table(table("n"), {"e": "string"}) - schema_with_m_n = { - "x": { - "a": "uint64", - }, - "y": { - "b": "string", - }, - "m": { - "d": "string", - }, - "n": { - "e": "string", - }, - } - self.assertEqual(schema.schema, schema_with_m_n) - new_schema = schema.copy() - new_schema.add_table(table("o"), {"f": "string"}) - new_schema.add_table(table("p"), {"g": "string"}) - self.assertEqual(schema.schema, schema_with_m_n) - self.assertEqual( - new_schema.schema, + def test_schema(self): + schema = ensure_schema( { "x": { "a": "uint64", }, "y": { - "b": "string", - }, - "m": { - "d": "string", - }, - "n": { - "e": "string", - }, - "o": { - "f": "string", - }, - "p": { - "g": "string", + "b": "uint64", + "c": "uint64", }, }, ) - schema = ensure_schema( - { - "db": { - "x": { - "a": "uint64", - } - } - } + self.assert_column_names( + schema, + ("x", ["a"]), + ("y", ["b", "c"]), + ("z.x", ["a"]), + ("z.x.y", ["b", "c"]), ) - self.assertEqual(schema.column_names(table("x", db="db")), ["a"]) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c")) - with self.assertRaises(ValueError): - schema.column_names(table("x")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db2")) - with self.assertRaises(ValueError): - schema.column_names(table("x2", db="db")) - with self.assertRaises(ValueError): - schema.add_table(table("y"), {"b": "string"}) - with self.assertRaises(ValueError): - schema.add_table(table("y", db="db", catalog="c"), {"b": "string"}) + self.assert_column_names_raises( + schema, + "z", + "z.z", + "z.z.z", + ) - schema.add_table(table("y", db="db"), {"b": "string"}) - self.assertEqual( - schema.schema, + def test_schema_db(self): + schema = ensure_schema( { - "db": { + "d1": { "x": { "a": "uint64", }, "y": { - "b": "string", + "b": "uint64", + }, + }, + "d2": { + "x": { + "c": "uint64", }, - } + }, }, ) - schema = ensure_schema( - { - "c": { - "db": { - "x": { - "a": "uint64", - } - } - } - } + self.assert_column_names( + schema, + ("d1.x", ["a"]), + ("d2.x", ["c"]), + ("y", ["b"]), + ("d1.y", ["b"]), + ("z.d1.y", ["b"]), ) - self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"]) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db")) - with self.assertRaises(ValueError): - schema.column_names(table("x")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c2")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db2")) - with self.assertRaises(ValueError): - schema.column_names(table("x2", db="db")) - with self.assertRaises(ValueError): - schema.add_table(table("x"), {"b": "string"}) - with self.assertRaises(ValueError): - schema.add_table(table("x", db="db"), {"b": "string"}) + self.assert_column_names_raises( + schema, + "x", + "z.x", + "z.y", + ) - schema.add_table(table("y", db="db", catalog="c"), {"a": "string", "b": "int"}) - self.assertEqual( - schema.schema, + def test_schema_catalog(self): + schema = ensure_schema( { - "c": { - "db": { + "c1": { + "d1": { "x": { "a": "uint64", }, "y": { - "a": "string", - "b": "int", + "b": "uint64", }, - } - } - }, - ) - schema.add_table(table("z", db="db2", catalog="c"), {"c": "string", "d": "int"}) - self.assertEqual( - schema.schema, - { - "c": { - "db": { - "x": { - "a": "uint64", + "z": { + "c": "uint64", }, + }, + }, + "c2": { + "d1": { "y": { - "a": "string", - "b": "int", + "d": "uint64", }, - }, - "db2": { "z": { - "c": "string", - "d": "int", - } - }, - } - }, - ) - schema.add_table(table("m", db="db2", catalog="c2"), {"e": "string", "f": "int"}) - self.assertEqual( - schema.schema, - { - "c": { - "db": { - "x": { - "a": "uint64", - }, - "y": { - "a": "string", - "b": "int", + "e": "uint64", }, }, - "db2": { + "d2": { "z": { - "c": "string", - "d": "int", - } + "f": "uint64", + }, }, }, - "c2": { - "db2": { - "m": { - "e": "string", - "f": "int", - } - } - }, - }, - ) - - schema = ensure_schema( - { - "x": { - "a": "uint64", - } } ) - self.assertEqual(schema.column_names(table("x")), ["a"]) - schema = MappingSchema() - schema.add_table(table("x"), {"a": "string"}) - self.assertEqual( - schema.schema, - { - "x": { - "a": "string", - } - }, + self.assert_column_names( + schema, + ("x", ["a"]), + ("d1.x", ["a"]), + ("c1.d1.x", ["a"]), + ("c1.d1.y", ["b"]), + ("c1.d1.z", ["c"]), + ("c2.d1.y", ["d"]), + ("c2.d1.z", ["e"]), + ("d2.z", ["f"]), + ("c2.d2.z", ["f"]), ) - schema.add_table(table("y"), df_types.StructType([df_types.StructField("b", df_types.StringType())])) - self.assertEqual( - schema.schema, - { - "x": { - "a": "string", - }, - "y": { - "b": "string", - }, - }, + + self.assert_column_names_raises( + schema, + "q", + "d2.x", + "y", + "z", + "d1.y", + "d1.z", + "a.b.c", ) def test_schema_add_table_with_and_without_mapping(self): @@ -288,3 +148,34 @@ class TestSchema(unittest.TestCase): self.assertEqual(schema.column_names("test"), ["x", "y"]) schema.add_table("test") self.assertEqual(schema.column_names("test"), ["x", "y"]) + + def test_schema_get_column_type(self): + schema = MappingSchema({"a": {"b": "varchar"}}) + self.assertEqual(schema.get_column_type("a", "b"), exp.DataType.Type.VARCHAR) + self.assertEqual( + schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")), + exp.DataType.Type.VARCHAR, + ) + self.assertEqual( + schema.get_column_type("a", exp.Column(this="b")), exp.DataType.Type.VARCHAR + ) + self.assertEqual( + schema.get_column_type(exp.Table(this="a"), "b"), exp.DataType.Type.VARCHAR + ) + schema = MappingSchema({"a": {"b": {"c": "varchar"}}}) + self.assertEqual( + schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")), + exp.DataType.Type.VARCHAR, + ) + self.assertEqual( + schema.get_column_type(exp.Table(this="b", db="a"), "c"), exp.DataType.Type.VARCHAR + ) + schema = MappingSchema({"a": {"b": {"c": {"d": "varchar"}}}}) + self.assertEqual( + schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")), + exp.DataType.Type.VARCHAR, + ) + self.assertEqual( + schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d"), + exp.DataType.Type.VARCHAR, + ) diff --git a/tests/test_tokens.py b/tests/test_tokens.py new file mode 100644 index 0000000..943c2b0 --- /dev/null +++ b/tests/test_tokens.py @@ -0,0 +1,18 @@ +import unittest + +from sqlglot.tokens import Tokenizer + + +class TestTokens(unittest.TestCase): + def test_comment_attachment(self): + tokenizer = Tokenizer() + sql_comment = [ + ("/*comment*/ foo", "comment"), + ("/*comment*/ foo --test", "comment"), + ("--comment\nfoo --test", "comment"), + ("foo --comment", "comment"), + ("foo", None), + ] + + for sql, comment in sql_comment: + self.assertEqual(tokenizer.tokenize(sql)[0].comment, comment) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 01b8205..942053e 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -49,6 +49,12 @@ class TestTranspile(unittest.TestCase): leading_comma=True, pretty=True, ) + self.validate( + "SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ", + "SELECT\n FOO -- x\n , BAR -- y\n , BAZ", + leading_comma=True, + pretty=True, + ) # without pretty, this should be a no-op self.validate( "SELECT FOO, BAR, BAZ", @@ -63,24 +69,61 @@ class TestTranspile(unittest.TestCase): self.validate("SELECT 3>=3", "SELECT 3 >= 3") def test_comments(self): - self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo") - self.validate("SELECT 1 /* inline */ FROM foo -- comment", "SELECT 1 FROM foo") - + self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */") + self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo") + self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo") + self.validate( + "SELECT 1 /* inline */ FROM foo -- comment", + "SELECT 1 /* inline */ FROM foo /* comment */", + ) + self.validate( + "SELECT FUN(x) /*x*/, [1,2,3] /*y*/", "SELECT FUN(x) /* x */, ARRAY(1, 2, 3) /* y */" + ) self.validate( """ SELECT 1 -- comment FROM foo -- comment """, - "SELECT 1 FROM foo", + "SELECT 1 /* comment */ FROM foo /* comment */", ) - self.validate( """ SELECT 1 /* big comment like this */ FROM foo -- comment """, - "SELECT 1 FROM foo", + """SELECT 1 /* big comment + like this */ FROM foo /* comment */""", + ) + self.validate( + "select x from foo -- x", + "SELECT x FROM foo /* x */", + ) + self.validate( + """ + /* multi + line + comment + */ + SELECT + tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */, + CAST(x AS INT), # comment 3 + y -- comment 4 + FROM + bar /* comment 5 */, + tbl # comment 6 + """, + """/* multi + line + comment + */ +SELECT + tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */, + CAST(x AS INT), -- comment 3 + y -- comment 4 +FROM bar /* comment 5 */, tbl /* comment 6 */""", + read="mysql", + pretty=True, ) def test_types(self): @@ -146,6 +189,16 @@ class TestTranspile(unittest.TestCase): def test_ignore_nulls(self): self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)") + def test_with(self): + self.validate( + "WITH a AS (SELECT 1) WITH b AS (SELECT 2) SELECT *", + "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *", + ) + self.validate( + "WITH a AS (SELECT 1), WITH b AS (SELECT 2) SELECT *", + "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *", + ) + def test_time(self): self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)") self.validate("TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)") |