From 1a60bbae98d3b530924a6807a55f8250de19ea86 Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Fri, 2 Dec 2022 10:16:29 +0100
Subject: Adding upstream version 10.1.3.

Signed-off-by: Daniel Baumann
---
 CHANGELOG.md                                      |  41 +++
 README.md                                         |  76 ++++-
 sqlglot/__init__.py                               |   2 +-
 sqlglot/dialects/bigquery.py                      |  20 +-
 sqlglot/dialects/clickhouse.py                    |  13 +-
 sqlglot/dialects/dialect.py                       |  21 +-
 sqlglot/dialects/drill.py                         |   2 +-
 sqlglot/dialects/hive.py                          |  14 +-
 sqlglot/dialects/mysql.py                         |   1 +
 sqlglot/dialects/oracle.py                        |  11 +-
 sqlglot/dialects/postgres.py                      |  48 +++-
 sqlglot/dialects/presto.py                        |  14 +-
 sqlglot/dialects/redshift.py                      |  18 +-
 sqlglot/dialects/snowflake.py                     |  25 +-
 sqlglot/dialects/spark.py                         |   2 +-
 sqlglot/dialects/sqlite.py                        |  18 ++
 sqlglot/dialects/tsql.py                          |  41 ++-
 sqlglot/errors.py                                 |  41 ++-
 sqlglot/executor/env.py                           |   1 -
 sqlglot/executor/python.py                        |  46 ++--
 sqlglot/expressions.py                            | 111 ++++----
 sqlglot/generator.py                              | 120 ++++----
 sqlglot/optimizer/eliminate_subqueries.py         |  59 +++-
 sqlglot/optimizer/lower_identities.py             |  92 +++++++
 sqlglot/optimizer/optimizer.py                    |   2 +
 sqlglot/optimizer/unnest_subqueries.py            |  36 ++-
 sqlglot/parser.py                                 | 321 ++++++++++++----------
 sqlglot/planner.py                                |  52 ++--
 sqlglot/tokens.py                                 |  45 +--
 sqlglot/transforms.py                             |  40 +++
 tests/dataframe/unit/test_functions.py            |   2 +-
 tests/dialects/test_clickhouse.py                 |   4 +
 tests/dialects/test_dialect.py                    |  11 +-
 tests/dialects/test_duckdb.py                     |   6 +
 tests/dialects/test_hive.py                       |   5 +-
 tests/dialects/test_mysql.py                      |  19 +-
 tests/dialects/test_postgres.py                   |  21 ++
 tests/dialects/test_presto.py                     |   7 +
 tests/dialects/test_redshift.py                   |  16 ++
 tests/dialects/test_snowflake.py                  |  20 +-
 tests/dialects/test_spark.py                      |  11 +-
 tests/dialects/test_sqlite.py                     |   4 +
 tests/dialects/test_tsql.py                       |  28 +-
 tests/fixtures/identity.sql                       |   7 +
 tests/fixtures/optimizer/eliminate_subqueries.sql |  12 +
 tests/fixtures/optimizer/lower_identities.sql     |  41 +++
 tests/fixtures/optimizer/optimizer.sql            |  15 +
 tests/fixtures/optimizer/simplify.sql             |   9 +
 tests/fixtures/optimizer/tpc-h/tpc-h.sql          |  57 ++--
 tests/fixtures/optimizer/unnest_subqueries.sql    |  84 +++---
 tests/test_executor.py                            |  71 ++++-
 tests/test_expressions.py                         |  28 +-
 tests/test_optimizer.py                           |   6 +-
 tests/test_parser.py                              |  67 ++++-
 tests/test_tokens.py                              |  14 +-
 tests/test_transforms.py                          |  29 +-
 tests/test_transpile.py                           | 120 +++++++-
 57 files changed, 1530 insertions(+), 517 deletions(-)
 create mode 100644 sqlglot/optimizer/lower_identities.py
 create mode 100644 tests/fixtures/optimizer/lower_identities.sql

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 70f2b55..a439c2c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,47 @@
 Changelog
 =========
 
+v10.1.0
+------
+
+Changes:
+
+- Breaking: [refactored](https://github.com/tobymao/sqlglot/commit/6b0da1e1a2b5d6bdf7b5b918400456422d30a1d4) the way SQL comments are handled. Previously, at most one comment could be attached to an expression; now multiple comments may be stored in a list.
+
+- Breaking: [refactored](https://github.com/tobymao/sqlglot/commit/be332d10404f36b43ea6ece956a73bf451348641) the way properties are represented and parsed. The argument `this` now stores a property's attributes instead of its name.
+
+- New: added structured ParseError properties.
+
+- New: the executor now handles set operations.
+
+- New: sqlglot can [now execute SQL queries](https://github.com/tobymao/sqlglot/commit/62d3496e761a4f38dfa61af793062690923dce74) using Python objects.
+
+- New: added support for the [Drill dialect](https://github.com/tobymao/sqlglot/commit/543eca314546e0bd42f97c354807b4e398ab36ec).
+
+- New: added a `canonicalize` method which leverages existing type information for an expression to apply various transformations to it.
+
+- New: TRIM function support for Snowflake and BigQuery.
+
+- New: added support for SQLite primary key ordering constraints (ASC, DESC).
+
+- New: added support for Redshift DISTKEY / SORTKEY / DISTSTYLE properties.
+
+- New: added support for SET TRANSACTION MySQL statements.
+
+- New: added `null`, `true`, `false` helper methods to avoid using singleton expressions.
+
+- Improvement: allow multiple aggregations in an expression.
+
+- Improvement: execution of left / right joins.
+
+- Improvement: execution of aggregations without the GROUP BY clause.
+
+- Improvement: static query execution (e.g. SELECT 1, SELECT CONCAT('a', 'b') AS x, etc.).
+
+- Improvement: the optimizer now includes a rule for type inference.
+
+- Improvement: transaction and commit expressions are now parsed [at finer granularity](https://github.com/tobymao/sqlglot/commit/148282e710fd79512bb7d32e6e519d631df8115d).
+
 v10.0.0
 ------
 
diff --git a/README.md b/README.md
index 2ceadfb..218d86c 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/
 * [AST Introspection](#ast-introspection)
 * [AST Diff](#ast-diff)
 * [Custom Dialects](#custom-dialects)
+* [SQL Execution](#sql-execution)
 * [Benchmarks](#benchmarks)
 * [Optional Dependencies](#optional-dependencies)
 
@@ -147,9 +148,9 @@ print(sqlglot.transpile(sql, read='mysql', pretty=True)[0])
 */
 SELECT
   tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
-  CAST(x AS INT), -- comment 3
-  y -- comment 4
-FROM bar /* comment 5 */, tbl /* comment 6*/
+  CAST(x AS INT), /* comment 3 */
+  y /* comment 4 */
+FROM bar /* comment 5 */, tbl /* comment 6 */
 ```
 
@@ -189,6 +190,28 @@ sqlglot.errors.ParseError: Expecting ). Line 1, Col: 13.
   ~~~~
 ```
 
+Structured syntax errors are accessible for programmatic use:
+
+```python
+import sqlglot
+try:
+    sqlglot.transpile("SELECT foo( FROM bar")
+except sqlglot.errors.ParseError as e:
+    print(e.errors)
+```
+
+Output:
+```python
+[{
+  'description': 'Expecting )',
+  'line': 1,
+  'col': 13,
+  'start_context': 'SELECT foo( ',
+  'highlight': 'FROM',
+  'end_context': ' bar'
+}]
+```
+
 ### Unsupported Errors
 
 Presto `APPROX_DISTINCT` supports the accuracy argument which is not supported in Hive:
@@ -372,6 +395,53 @@ print(Dialect["custom"])
 ```
 
+### SQL Execution
+
+One can even interpret SQL queries using SQLGlot, where the tables are represented as Python dictionaries. Although the engine is not very fast (it's not supposed to be) and is in a relatively early stage of development, it can be useful for unit testing and for running SQL natively across Python objects. Additionally, the foundation can be easily integrated with fast compute kernels (arrow, pandas).
Below is an example showcasing the execution of a SELECT expression that involves aggregations and JOINs: + +```python +from sqlglot.executor import execute + +tables = { + "sushi": [ + {"id": 1, "price": 1.0}, + {"id": 2, "price": 2.0}, + {"id": 3, "price": 3.0}, + ], + "order_items": [ + {"sushi_id": 1, "order_id": 1}, + {"sushi_id": 1, "order_id": 1}, + {"sushi_id": 2, "order_id": 1}, + {"sushi_id": 3, "order_id": 2}, + ], + "orders": [ + {"id": 1, "user_id": 1}, + {"id": 2, "user_id": 2}, + ], +} + +execute( + """ + SELECT + o.user_id, + SUM(s.price) AS price + FROM orders o + JOIN order_items i + ON o.id = i.order_id + JOIN sushi s + ON i.sushi_id = s.id + GROUP BY o.user_id + """, + tables=tables +) +``` + +```python +user_id price + 1 4.0 + 2 3.0 +``` + ## Benchmarks [Benchmarks](benchmarks) run on Python 3.10.5 in seconds. diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 50e2d9c..b027ac7 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -30,7 +30,7 @@ from sqlglot.parser import Parser from sqlglot.schema import MappingSchema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "10.0.8" +__version__ = "10.1.3" pretty = False diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 4550d65..5b44912 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -56,12 +56,12 @@ def _derived_table_values_to_unnest(self, expression): def _returnsproperty_sql(self, expression): - value = expression.args.get("value") - if isinstance(value, exp.Schema): - value = f"{value.this} <{self.expressions(value)}>" + this = expression.this + if isinstance(this, exp.Schema): + this = f"{this.this} <{self.expressions(this)}>" else: - value = self.sql(value) - return f"RETURNS {value}" + this = self.sql(this) + return f"RETURNS {this}" def _create_sql(self, expression): @@ -142,6 +142,11 @@ class BigQuery(Dialect): ), } + FUNCTION_PARSERS = { + **parser.Parser.FUNCTION_PARSERS, + } + FUNCTION_PARSERS.pop("TRIM") + NO_PAREN_FUNCTIONS = { **parser.Parser.NO_PAREN_FUNCTIONS, TokenType.CURRENT_DATETIME: exp.CurrentDatetime, @@ -174,6 +179,7 @@ class BigQuery(Dialect): exp.Values: _derived_table_values_to_unnest, exp.ReturnsProperty: _returnsproperty_sql, exp.Create: _create_sql, + exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})", exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC", @@ -200,9 +206,7 @@ class BigQuery(Dialect): exp.VolatilityProperty, } - WITH_PROPERTIES = { - exp.AnonymousProperty, - } + WITH_PROPERTIES = {exp.Property} EXPLICIT_UNION = True diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 332b4c1..cbed72e 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -21,14 +21,15 @@ class ClickHouse(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, - "FINAL": TokenType.FINAL, + "ASOF": TokenType.ASOF, "DATETIME64": TokenType.DATETIME, - "INT8": TokenType.TINYINT, + "FINAL": TokenType.FINAL, + "FLOAT32": TokenType.FLOAT, + "FLOAT64": TokenType.DOUBLE, "INT16": TokenType.SMALLINT, "INT32": TokenType.INT, "INT64": TokenType.BIGINT, - "FLOAT32": TokenType.FLOAT, - "FLOAT64": TokenType.DOUBLE, + "INT8": TokenType.TINYINT, "TUPLE": TokenType.STRUCT, } @@ -38,6 +39,10 @@ class ClickHouse(Dialect): "MAP": parse_var_map, } + JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} + + TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} + def 
_parse_table(self, schema=False): this = super()._parse_table(schema) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 8c497ab..c87f8d8 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -289,19 +289,19 @@ def struct_extract_sql(self, expression): return f"{this}.{struct_key}" -def var_map_sql(self, expression): +def var_map_sql(self, expression, map_func_name="MAP"): keys = expression.args["keys"] values = expression.args["values"] if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): self.unsupported("Cannot convert array columns into map.") - return f"MAP({self.format_args(keys, values)})" + return f"{map_func_name}({self.format_args(keys, values)})" args = [] for key, value in zip(keys.expressions, values.expressions): args.append(self.sql(key)) args.append(self.sql(value)) - return f"MAP({self.format_args(*args)})" + return f"{map_func_name}({self.format_args(*args)})" def format_time_lambda(exp_class, dialect, default=None): @@ -336,18 +336,13 @@ def create_with_partitions_sql(self, expression): if has_schema and is_partitionable: expression = expression.copy() prop = expression.find(exp.PartitionedByProperty) - value = prop and prop.args.get("value") - if prop and not isinstance(value, exp.Schema): + this = prop and prop.this + if prop and not isinstance(this, exp.Schema): schema = expression.this - columns = {v.name.upper() for v in value.expressions} + columns = {v.name.upper() for v in this.expressions} partitions = [col for col in schema.expressions if col.name.upper() in columns] - schema.set( - "expressions", - [e for e in schema.expressions if e not in partitions], - ) - prop.replace( - exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions)) - ) + schema.set("expressions", [e for e in schema.expressions if e not in partitions]) + prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) expression.set("this", schema) return self.create_sql(expression) diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index eb420aa..358eced 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -153,7 +153,7 @@ class Drill(Dialect): exp.If: if_sql, exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}", exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), - exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}", + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.Pivot: no_pivot_sql, exp.RegexpLike: rename_func("REGEXP_MATCHES"), exp.StrPosition: str_position_sql, diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index cff7139..cbb39c2 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -61,9 +61,7 @@ def _array_sort(self, expression): def _property_sql(self, expression): - key = expression.name - value = self.sql(expression, "value") - return f"'{key}'={value}" + return f"'{expression.name}'={self.sql(expression, 'value')}" def _str_to_unix(self, expression): @@ -250,7 +248,7 @@ class Hive(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, **transforms.UNALIAS_GROUP, # type: ignore - exp.AnonymousProperty: _property_sql, + exp.Property: _property_sql, exp.ApproxDistinct: approx_count_distinct_sql, exp.ArrayAgg: rename_func("COLLECT_LIST"), exp.ArrayConcat: rename_func("CONCAT"), @@ -262,7 +260,7 @@ class Hive(Dialect): exp.DateStrToDate: rename_func("TO_DATE"), exp.DateToDi: lambda self, e: 
f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})", - exp.FileFormatProperty: lambda self, e: f"STORED AS {e.text('value').upper()}", + exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}", exp.If: if_sql, exp.Index: _index_sql, exp.ILike: no_ilike_sql, @@ -285,7 +283,7 @@ class Hive(Dialect): exp.StrToTime: _str_to_time, exp.StrToUnix: _str_to_unix, exp.StructExtract: struct_extract_sql, - exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'value')}", + exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}", exp.TimeStrToDate: rename_func("TO_DATE"), exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), @@ -298,11 +296,11 @@ class Hive(Dialect): exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})", exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), - exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}", + exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}", exp.NumberToStr: rename_func("FORMAT_NUMBER"), } - WITH_PROPERTIES = {exp.AnonymousProperty} + WITH_PROPERTIES = {exp.Property} ROOT_PROPERTIES = { exp.PartitionedByProperty, diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 93a60f4..7627b6e 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -453,6 +453,7 @@ class MySQL(Dialect): exp.CharacterSetProperty, exp.CollateProperty, exp.SchemaCommentProperty, + exp.LikeProperty, } WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set() diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 870d2b9..ceaf9ba 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -1,7 +1,7 @@ from __future__ import annotations -from sqlglot import exp, generator, tokens, transforms -from sqlglot.dialects.dialect import Dialect, no_ilike_sql +from sqlglot import exp, generator, parser, tokens, transforms +from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func from sqlglot.helper import csv from sqlglot.tokens import TokenType @@ -37,6 +37,12 @@ class Oracle(Dialect): "YYYY": "%Y", # 2015 } + class Parser(parser.Parser): + FUNCTIONS = { + **parser.Parser.FUNCTIONS, + "DECODE": exp.Matches.from_arg_list, + } + class Generator(generator.Generator): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -58,6 +64,7 @@ class Oracle(Dialect): **transforms.UNALIAS_GROUP, # type: ignore exp.ILike: no_ilike_sql, exp.Limit: _limit_sql, + exp.Matches: rename_func("DECODE"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 4353164..1cb5025 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -74,6 +74,27 @@ def _trim_sql(self, expression): return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" +def _string_agg_sql(self, expression): + expression = expression.copy() + separator = expression.args.get("separator") or exp.Literal.string(",") + + order = "" + this = expression.this 
+ if isinstance(this, exp.Order): + if this.this: + this = this.this + this.pop() + order = self.sql(expression.this) # Order has a leading space + + return f"STRING_AGG({self.format_args(this, separator)}{order})" + + +def _datatype_sql(self, expression): + if expression.this == exp.DataType.Type.ARRAY: + return f"{self.expressions(expression, flat=True)}[]" + return self.datatype_sql(expression) + + def _auto_increment_to_serial(expression): auto = expression.find(exp.AutoIncrementColumnConstraint) @@ -191,25 +212,27 @@ class Postgres(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "ALWAYS": TokenType.ALWAYS, - "BY DEFAULT": TokenType.BY_DEFAULT, - "IDENTITY": TokenType.IDENTITY, - "GENERATED": TokenType.GENERATED, - "DOUBLE PRECISION": TokenType.DOUBLE, - "BIGSERIAL": TokenType.BIGSERIAL, - "SERIAL": TokenType.SERIAL, - "SMALLSERIAL": TokenType.SMALLSERIAL, - "UUID": TokenType.UUID, - "TEMP": TokenType.TEMPORARY, - "BEGIN TRANSACTION": TokenType.BEGIN, "BEGIN": TokenType.COMMAND, + "BEGIN TRANSACTION": TokenType.BEGIN, + "BIGSERIAL": TokenType.BIGSERIAL, + "BY DEFAULT": TokenType.BY_DEFAULT, "COMMENT ON": TokenType.COMMAND, "DECLARE": TokenType.COMMAND, "DO": TokenType.COMMAND, + "DOUBLE PRECISION": TokenType.DOUBLE, + "GENERATED": TokenType.GENERATED, + "GRANT": TokenType.COMMAND, + "HSTORE": TokenType.HSTORE, + "IDENTITY": TokenType.IDENTITY, + "JSONB": TokenType.JSONB, "REFRESH": TokenType.COMMAND, "REINDEX": TokenType.COMMAND, "RESET": TokenType.COMMAND, "REVOKE": TokenType.COMMAND, - "GRANT": TokenType.COMMAND, + "SERIAL": TokenType.SERIAL, + "SMALLSERIAL": TokenType.SMALLSERIAL, + "TEMP": TokenType.TEMPORARY, + "UUID": TokenType.UUID, **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES}, **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES}, } @@ -265,4 +288,7 @@ class Postgres(Dialect): exp.Trim: _trim_sql, exp.TryCast: no_trycast_sql, exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", + exp.DataType: _datatype_sql, + exp.GroupConcat: _string_agg_sql, + exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", } diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 9d5cc11..1a09037 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -171,16 +171,7 @@ class Presto(Dialect): STRUCT_DELIMITER = ("(", ")") - ROOT_PROPERTIES = { - exp.SchemaCommentProperty, - } - - WITH_PROPERTIES = { - exp.PartitionedByProperty, - exp.FileFormatProperty, - exp.AnonymousProperty, - exp.TableFormatProperty, - } + ROOT_PROPERTIES = {exp.SchemaCommentProperty} TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -231,7 +222,8 @@ class Presto(Dialect): exp.StrToTime: _str_to_time_sql, exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", exp.StructExtract: struct_extract_sql, - exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT = '{e.text('value').upper()}'", + exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'", + exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.TimeStrToDate: _date_parse_sql, exp.TimeStrToTime: _date_parse_sql, exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))", diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index a9b12fb..cd50979 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -1,6 +1,6 @@ from __future__ import annotations -from sqlglot import exp +from sqlglot import 
exp, transforms from sqlglot.dialects.postgres import Postgres from sqlglot.tokens import TokenType @@ -18,12 +18,14 @@ class Redshift(Postgres): KEYWORDS = { **Postgres.Tokenizer.KEYWORDS, # type: ignore + "COPY": TokenType.COMMAND, "GEOMETRY": TokenType.GEOMETRY, "GEOGRAPHY": TokenType.GEOGRAPHY, "HLLSKETCH": TokenType.HLLSKETCH, "SUPER": TokenType.SUPER, "TIME": TokenType.TIMESTAMP, "TIMETZ": TokenType.TIMESTAMPTZ, + "UNLOAD": TokenType.COMMAND, "VARBYTE": TokenType.VARBINARY, "SIMILAR TO": TokenType.SIMILAR_TO, } @@ -35,3 +37,17 @@ class Redshift(Postgres): exp.DataType.Type.VARBINARY: "VARBYTE", exp.DataType.Type.INT: "INTEGER", } + + ROOT_PROPERTIES = { + exp.DistKeyProperty, + exp.SortKeyProperty, + exp.DistStyleProperty, + } + + TRANSFORMS = { + **Postgres.Generator.TRANSFORMS, # type: ignore + **transforms.ELIMINATE_DISTINCT_ON, # type: ignore + exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})", + exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", + exp.DistStyleProperty: lambda self, e: self.naked_property(e), + } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index a96bd80..46155ff 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import ( format_time_lambda, inline_array_sql, rename_func, + var_map_sql, ) from sqlglot.expressions import Literal from sqlglot.helper import seq_get @@ -100,6 +101,14 @@ def _parse_date_part(self): return self.expression(exp.Extract, this=this, expression=expression) +def _datatype_sql(self, expression): + if expression.this == exp.DataType.Type.ARRAY: + return "ARRAY" + elif expression.this == exp.DataType.Type.MAP: + return "OBJECT" + return self.datatype_sql(expression) + + class Snowflake(Dialect): null_ordering = "nulls_are_large" time_format = "'yyyy-mm-dd hh24:mi:ss'" @@ -142,6 +151,8 @@ class Snowflake(Dialect): "TO_TIMESTAMP": _snowflake_to_timestamp, "ARRAY_CONSTRUCT": exp.Array.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list, + "DECODE": exp.Matches.from_arg_list, + "OBJECT_CONSTRUCT": parser.parse_var_map, } FUNCTION_PARSERS = { @@ -195,16 +206,20 @@ class Snowflake(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), + exp.DataType: _datatype_sql, exp.If: rename_func("IFF"), + exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), + exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), + exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}", + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", + exp.Matches: rename_func("DECODE"), + exp.StrPosition: rename_func("POSITION"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", - exp.UnixToTime: _unix_to_time_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", - exp.Array: inline_array_sql, - exp.StrPosition: rename_func("POSITION"), - exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}", - exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}", exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})", + exp.UnixToTime: _unix_to_time_sql, } TYPE_MAPPING = { diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 4e404b8..16083d1 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -98,7 +98,7 @@ class 
Spark(Hive): TRANSFORMS = { **Hive.Generator.TRANSFORMS, # type: ignore exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), - exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}", + exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 87b98a5..bbb752b 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -13,6 +13,23 @@ from sqlglot.dialects.dialect import ( from sqlglot.tokens import TokenType +# https://www.sqlite.org/lang_aggfunc.html#group_concat +def _group_concat_sql(self, expression): + this = expression.this + distinct = expression.find(exp.Distinct) + if distinct: + this = distinct.expressions[0] + distinct = "DISTINCT " + + if isinstance(expression.this, exp.Order): + self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.") + if expression.this.this and not distinct: + this = expression.this.this + + separator = expression.args.get("separator") + return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})" + + class SQLite(Dialect): class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]"), "`"] @@ -62,6 +79,7 @@ class SQLite(Dialect): exp.Levenshtein: rename_func("EDITDIST3"), exp.TableSample: no_tablesample_sql, exp.TryCast: no_trycast_sql, + exp.GroupConcat: _group_concat_sql, } def transaction_sql(self, expression): diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index d3b83de..07ce38b 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -17,6 +17,7 @@ FULL_FORMAT_TIME_MAPPING = { "mm": "%B", "m": "%B", } + DATE_DELTA_INTERVAL = { "year": "year", "yyyy": "year", @@ -37,11 +38,12 @@ DATE_DELTA_INTERVAL = { DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})") + # N = Numeric, C=Currency TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"} -def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None): +def _format_time_lambda(exp_class, full_format_mapping=None, default=None): def _format_time(args): return exp_class( this=seq_get(args, 1), @@ -58,7 +60,7 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None): return _format_time -def parse_format(args): +def _parse_format(args): fmt = seq_get(args, 1) number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this) if number_fmt: @@ -78,7 +80,7 @@ def generate_date_delta_with_unit_sql(self, e): return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})" -def generate_format_sql(self, e): +def _format_sql(self, e): fmt = ( e.args["format"] if isinstance(e, exp.NumberToStr) @@ -87,6 +89,28 @@ def generate_format_sql(self, e): return f"FORMAT({self.format_args(e.this, fmt)})" +def _string_agg_sql(self, e): + e = e.copy() + + this = e.this + distinct = e.find(exp.Distinct) + if distinct: + # exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression + self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.") + this = distinct.expressions[0] + distinct.pop() + + order = "" + if isinstance(e.this, exp.Order): + if e.this.this: + this = e.this.this + e.this.this.pop() + order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space + + separator = e.args.get("separator") or exp.Literal.string(",") 
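+    # T-SQL's STRING_AGG has no implicit separator, so fall back to a comma to mirror GROUP_CONCAT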
+ return f"STRING_AGG({self.format_args(this, separator)}){order}" + + class TSQL(Dialect): null_ordering = "nulls_are_small" time_format = "'yyyy-mm-dd hh:mm:ss'" @@ -228,14 +252,14 @@ class TSQL(Dialect): "ISNULL": exp.Coalesce.from_arg_list, "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), - "DATENAME": tsql_format_time_lambda(exp.TimeToStr, full_format_mapping=True), - "DATEPART": tsql_format_time_lambda(exp.TimeToStr), + "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True), + "DATEPART": _format_time_lambda(exp.TimeToStr), "GETDATE": exp.CurrentDate.from_arg_list, "IIF": exp.If.from_arg_list, "LEN": exp.Length.from_arg_list, "REPLICATE": exp.Repeat.from_arg_list, "JSON_VALUE": exp.JSONExtractScalar.from_arg_list, - "FORMAT": parse_format, + "FORMAT": _parse_format, } VAR_LENGTH_DATATYPES = { @@ -298,6 +322,7 @@ class TSQL(Dialect): exp.DateDiff: generate_date_delta_with_unit_sql, exp.CurrentDate: rename_func("GETDATE"), exp.If: rename_func("IIF"), - exp.NumberToStr: generate_format_sql, - exp.TimeToStr: generate_format_sql, + exp.NumberToStr: _format_sql, + exp.TimeToStr: _format_sql, + exp.GroupConcat: _string_agg_sql, } diff --git a/sqlglot/errors.py b/sqlglot/errors.py index 23a08bd..b5ef5ad 100644 --- a/sqlglot/errors.py +++ b/sqlglot/errors.py @@ -22,7 +22,40 @@ class UnsupportedError(SqlglotError): class ParseError(SqlglotError): - pass + def __init__( + self, + message: str, + errors: t.Optional[t.List[t.Dict[str, t.Any]]] = None, + ): + super().__init__(message) + self.errors = errors or [] + + @classmethod + def new( + cls, + message: str, + description: t.Optional[str] = None, + line: t.Optional[int] = None, + col: t.Optional[int] = None, + start_context: t.Optional[str] = None, + highlight: t.Optional[str] = None, + end_context: t.Optional[str] = None, + into_expression: t.Optional[str] = None, + ) -> ParseError: + return cls( + message, + [ + { + "description": description, + "line": line, + "col": col, + "start_context": start_context, + "highlight": highlight, + "end_context": end_context, + "into_expression": into_expression, + } + ], + ) class TokenError(SqlglotError): @@ -41,9 +74,13 @@ class ExecuteError(SqlglotError): pass -def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str: +def concat_messages(errors: t.Sequence[t.Any], maximum: int) -> str: msg = [str(e) for e in errors[:maximum]] remaining = len(errors) - maximum if remaining > 0: msg.append(f"... 
and {remaining} more") return "\n\n".join(msg) + + +def merge_errors(errors: t.Sequence[ParseError]) -> t.List[t.Dict[str, t.Any]]: + return [e_dict for error in errors for e_dict in error.errors] diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index ed80cc9..e6cfcdd 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -122,7 +122,6 @@ def interval(this, unit): ENV = { - "__builtins__": {}, "exp": exp, # aggs "SUM": filter_nulls(sum), diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index cb2543c..908b80a 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -115,6 +115,9 @@ class PythonExecutor: sink = self.table(context.columns) for reader in table_iter: + if len(sink) >= step.limit: + break + if condition and not context.eval(condition): continue @@ -123,9 +126,6 @@ class PythonExecutor: else: sink.append(reader.row) - if len(sink) >= step.limit: - break - return self.context({step.name: sink}) def static(self): @@ -288,21 +288,32 @@ class PythonExecutor: end = 1 length = len(context.table) table = self.table(list(step.group) + step.aggregations) + condition = self.generate(step.condition) - for i in range(length): - context.set_index(i) - key = context.eval_tuple(group_by) - group = key if group is None else group - end += 1 - if key != group: - context.set_range(start, end - 2) - table.append(group + context.eval_tuple(aggregations)) - group = key - start = end - 2 - if i == length - 1: - context.set_range(start, end - 1) + def add_row(): + if not condition or context.eval(condition): table.append(group + context.eval_tuple(aggregations)) + if length: + for i in range(length): + context.set_index(i) + key = context.eval_tuple(group_by) + group = key if group is None else group + end += 1 + if key != group: + context.set_range(start, end - 2) + add_row() + group = key + start = end - 2 + if len(table.rows) >= step.limit: + break + if i == length - 1: + context.set_range(start, end - 1) + add_row() + elif step.limit > 0: + context.set_range(0, 0) + table.append(context.eval_tuple(group_by) + context.eval_tuple(aggregations)) + context = self.context({step.name: table, **{name: table for name in context.tables}}) if step.projections: @@ -311,11 +322,9 @@ class PythonExecutor: def sort(self, step, context): projections = self.generate_tuple(step.projections) - projection_columns = [p.alias_or_name for p in step.projections] all_columns = list(context.columns) + projection_columns sink = self.table(all_columns) - for reader, ctx in context: sink.append(reader.row + ctx.eval_tuple(projections)) @@ -401,8 +410,9 @@ class Python(Dialect): exp.Boolean: lambda self, e: "True" if e.this else "False", exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})", exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]", + exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})", exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", - exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}", + exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})", exp.Is: lambda self, e: self.binary(e, "is"), exp.Not: lambda self, e: f"not {self.sql(e.this)}", exp.Null: lambda *_: "None", diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index beafca8..96b32f1 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -43,14 +43,14 @@ class Expression(metaclass=_Expression): 
key = "Expression" arg_types = {"this": True} - __slots__ = ("args", "parent", "arg_key", "type", "comment") + __slots__ = ("args", "parent", "arg_key", "type", "comments") def __init__(self, **args): self.args = args self.parent = None self.arg_key = None self.type = None - self.comment = None + self.comments = None for arg_key, value in self.args.items(): self._set_parent(arg_key, value) @@ -88,19 +88,6 @@ class Expression(metaclass=_Expression): return field.this return "" - def find_comment(self, key: str) -> str: - """ - Finds the comment that is attached to a specified child node. - - Args: - key: the key of the target child node (e.g. "this", "expression", etc). - - Returns: - The comment attached to the child node, or the empty string, if it doesn't exist. - """ - field = self.args.get(key) - return field.comment if isinstance(field, Expression) else "" - @property def is_string(self): return isinstance(self, Literal) and self.args["is_string"] @@ -137,7 +124,7 @@ class Expression(metaclass=_Expression): def __deepcopy__(self, memo): copy = self.__class__(**deepcopy(self.args)) - copy.comment = self.comment + copy.comments = self.comments copy.type = self.type return copy @@ -369,7 +356,7 @@ class Expression(metaclass=_Expression): ) for k, vs in self.args.items() } - args["comment"] = self.comment + args["comments"] = self.comments args["type"] = self.type args = {k: v for k, v in args.items() if v or not hide_missing} @@ -767,7 +754,7 @@ class NotNullColumnConstraint(ColumnConstraintKind): class PrimaryKeyColumnConstraint(ColumnConstraintKind): - pass + arg_types = {"desc": False} class UniqueColumnConstraint(ColumnConstraintKind): @@ -819,6 +806,12 @@ class Unique(Expression): arg_types = {"expressions": True} +# https://www.postgresql.org/docs/9.1/sql-selectinto.html +# https://docs.aws.amazon.com/redshift/latest/dg/r_SELECT_INTO.html#r_SELECT_INTO-examples +class Into(Expression): + arg_types = {"this": True, "temporary": False, "unlogged": False} + + class From(Expression): arg_types = {"expressions": True} @@ -1065,67 +1058,67 @@ class Property(Expression): class TableFormatProperty(Property): - pass + arg_types = {"this": True} class PartitionedByProperty(Property): - pass + arg_types = {"this": True} class FileFormatProperty(Property): - pass + arg_types = {"this": True} class DistKeyProperty(Property): - pass + arg_types = {"this": True} class SortKeyProperty(Property): - pass + arg_types = {"this": True, "compound": False} class DistStyleProperty(Property): - pass + arg_types = {"this": True} + + +class LikeProperty(Property): + arg_types = {"this": True, "expressions": False} class LocationProperty(Property): - pass + arg_types = {"this": True} class EngineProperty(Property): - pass + arg_types = {"this": True} class AutoIncrementProperty(Property): - pass + arg_types = {"this": True} class CharacterSetProperty(Property): - arg_types = {"this": True, "value": True, "default": True} + arg_types = {"this": True, "default": True} class CollateProperty(Property): - pass + arg_types = {"this": True} class SchemaCommentProperty(Property): - pass - - -class AnonymousProperty(Property): - pass + arg_types = {"this": True} class ReturnsProperty(Property): - arg_types = {"this": True, "value": True, "is_table": False} + arg_types = {"this": True, "is_table": False} class LanguageProperty(Property): - pass + arg_types = {"this": True} class ExecuteAsProperty(Property): - pass + arg_types = {"this": True} class VolatilityProperty(Property): @@ -1135,27 +1128,36 @@ class 
VolatilityProperty(Property): class Properties(Expression): arg_types = {"expressions": True} - PROPERTY_KEY_MAPPING = { + NAME_TO_PROPERTY = { "AUTO_INCREMENT": AutoIncrementProperty, - "CHARACTER_SET": CharacterSetProperty, + "CHARACTER SET": CharacterSetProperty, "COLLATE": CollateProperty, "COMMENT": SchemaCommentProperty, + "DISTKEY": DistKeyProperty, + "DISTSTYLE": DistStyleProperty, "ENGINE": EngineProperty, + "EXECUTE AS": ExecuteAsProperty, "FORMAT": FileFormatProperty, + "LANGUAGE": LanguageProperty, "LOCATION": LocationProperty, "PARTITIONED_BY": PartitionedByProperty, - "TABLE_FORMAT": TableFormatProperty, - "DISTKEY": DistKeyProperty, - "DISTSTYLE": DistStyleProperty, + "RETURNS": ReturnsProperty, "SORTKEY": SortKeyProperty, + "TABLE_FORMAT": TableFormatProperty, } + PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()} + @classmethod def from_dict(cls, properties_dict) -> Properties: expressions = [] for key, value in properties_dict.items(): - property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty) - expressions.append(property_cls(this=Literal.string(key), value=convert(value))) + property_cls = cls.NAME_TO_PROPERTY.get(key.upper()) + if property_cls: + expressions.append(property_cls(this=convert(value))) + else: + expressions.append(Property(this=Literal.string(key), value=convert(value))) + return cls(expressions=expressions) @@ -1383,6 +1385,7 @@ class Select(Subqueryable): "expressions": False, "hint": False, "distinct": False, + "into": False, "from": False, **QUERY_MODIFIERS, } @@ -2015,6 +2018,7 @@ class DataType(Expression): DECIMAL = auto() BOOLEAN = auto() JSON = auto() + JSONB = auto() INTERVAL = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() @@ -2029,6 +2033,7 @@ class DataType(Expression): STRUCT = auto() NULLABLE = auto() HLLSKETCH = auto() + HSTORE = auto() SUPER = auto() SERIAL = auto() SMALLSERIAL = auto() @@ -2109,7 +2114,7 @@ class Transaction(Command): class Commit(Command): - arg_types = {} # type: ignore + arg_types = {"chain": False} class Rollback(Command): @@ -2442,7 +2447,7 @@ class ArrayFilter(Func): class ArraySize(Func): - pass + arg_types = {"this": True, "expression": False} class ArraySort(Func): @@ -2726,6 +2731,16 @@ class VarMap(Func): is_var_len_args = True +class Matches(Func): + """Oracle/Snowflake decode. 
+ https://docs.oracle.com/cd/B19306_01/server.102/b14200/functions040.htm + Pattern matching MATCHES(value, search1, result1, ...searchN, resultN, else) + """ + + arg_types = {"this": True, "expressions": True} + is_var_len_args = True + + class Max(AggFunc): pass @@ -2785,6 +2800,10 @@ class Round(Func): arg_types = {"this": True, "decimals": False} +class RowNumber(Func): + arg_types: t.Dict[str, t.Any] = {} + + class SafeDivide(Func): arg_types = {"this": True, "expression": True} diff --git a/sqlglot/generator.py b/sqlglot/generator.py index ffb34eb..47774fc 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1,19 +1,16 @@ from __future__ import annotations import logging -import re import typing as t from sqlglot import exp -from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors +from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages from sqlglot.helper import apply_index_offset, csv from sqlglot.time import format_time from sqlglot.tokens import TokenType logger = logging.getLogger("sqlglot") -NEWLINE_RE = re.compile("\r\n?|\n") - class Generator: """ @@ -58,11 +55,11 @@ class Generator: """ TRANSFORMS = { - exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}", exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})", exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})", exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})", exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})", + exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}", exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e), exp.ReturnsProperty: lambda self, e: self.naked_property(e), @@ -97,16 +94,17 @@ class Generator: exp.DistStyleProperty, exp.DistKeyProperty, exp.SortKeyProperty, + exp.LikeProperty, } WITH_PROPERTIES = { - exp.AnonymousProperty, + exp.Property, exp.FileFormatProperty, exp.PartitionedByProperty, exp.TableFormatProperty, } - WITH_SEPARATED_COMMENTS = (exp.Select,) + WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary) __slots__ = ( "time_mapping", @@ -211,7 +209,7 @@ class Generator: for msg in self.unsupported_messages: logger.warning(msg) elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages: - raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported)) + raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported)) return sql @@ -226,25 +224,24 @@ class Generator: def seg(self, sql, sep=" "): return f"{self.sep(sep)}{sql}" - def maybe_comment(self, sql, expression, single_line=False): - comment = expression.comment if self._comments else None - - if not comment: - return sql - + def pad_comment(self, comment): comment = " " + comment if comment[0].strip() else comment comment = comment + " " if comment[-1].strip() else comment + return comment - if isinstance(expression, self.WITH_SEPARATED_COMMENTS): - return f"/*{comment}*/{self.sep()}{sql}" + def maybe_comment(self, sql, expression): + comments = expression.comments if self._comments else None - if not self.pretty: - return f"{sql} /*{comment}*/" + if not comments: + return sql + + sep = "\n" if self.pretty else " " + comments = 
sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments) - if not NEWLINE_RE.search(comment): - return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/" + if isinstance(expression, self.WITH_SEPARATED_COMMENTS): + return f"{comments}{self.sep()}{sql}" - return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/" + return f"{sql} {comments}" def wrap(self, expression): this_sql = self.indent( @@ -387,8 +384,11 @@ class Generator: def notnullcolumnconstraint_sql(self, _): return "NOT NULL" - def primarykeycolumnconstraint_sql(self, _): - return "PRIMARY KEY" + def primarykeycolumnconstraint_sql(self, expression): + desc = expression.args.get("desc") + if desc is not None: + return f"PRIMARY KEY{' DESC' if desc else ' ASC'}" + return f"PRIMARY KEY" def uniquecolumnconstraint_sql(self, _): return "UNIQUE" @@ -546,36 +546,33 @@ class Generator: def root_properties(self, properties): if properties.expressions: - return self.sep() + self.expressions( - properties, - indent=False, - sep=" ", - ) + return self.sep() + self.expressions(properties, indent=False, sep=" ") return "" def properties(self, properties, prefix="", sep=", "): if properties.expressions: - expressions = self.expressions( - properties, - sep=sep, - indent=False, - ) + expressions = self.expressions(properties, sep=sep, indent=False) return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}" return "" def with_properties(self, properties): - return self.properties( - properties, - prefix="WITH", - ) + return self.properties(properties, prefix="WITH") def property_sql(self, expression): - if isinstance(expression.this, exp.Literal): - key = expression.this.this - else: - key = expression.name - value = self.sql(expression, "value") - return f"{key}={value}" + property_cls = expression.__class__ + if property_cls == exp.Property: + return f"{expression.name}={self.sql(expression, 'value')}" + + property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls) + if not property_name: + self.unsupported(f"Unsupported property {property_name}") + + return f"{property_name}={self.sql(expression, 'this')}" + + def likeproperty_sql(self, expression): + options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions) + options = f" {options}" if options else "" + return f"LIKE {self.sql(expression, 'this')}{options}" def insert_sql(self, expression): overwrite = expression.args.get("overwrite") @@ -700,6 +697,11 @@ class Generator: def var_sql(self, expression): return self.sql(expression, "this") + def into_sql(self, expression): + temporary = " TEMPORARY" if expression.args.get("temporary") else "" + unlogged = " UNLOGGED" if expression.args.get("unlogged") else "" + return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}" + def from_sql(self, expression): expressions = self.expressions(expression, flat=True) return f"{self.seg('FROM')} {expressions}" @@ -883,6 +885,7 @@ class Generator: sql = self.query_modifiers( expression, f"SELECT{hint}{distinct}{expressions}", + self.sql(expression, "into", comment=False), self.sql(expression, "from", comment=False), ) return self.prepend_ctes(expression, sql) @@ -1061,6 +1064,11 @@ class Generator: else: return f"TRIM({target})" + def concat_sql(self, expression): + if len(expression.expressions) == 1: + return self.sql(expression.expressions[0]) + return self.function_fallback_sql(expression) + def check_sql(self, expression): this = self.sql(expression, key="this") return f"CHECK ({this})" @@ 
-1125,7 +1133,10 @@ class Generator: return self.prepend_ctes(expression, sql) def neg_sql(self, expression): - return f"-{self.sql(expression, 'this')}" + # This makes sure we don't convert "- - 5" to "--5", which is a comment + this_sql = self.sql(expression, "this") + sep = " " if this_sql[0] == "-" else "" + return f"-{sep}{this_sql}" def not_sql(self, expression): return f"NOT {self.sql(expression, 'this')}" @@ -1191,8 +1202,12 @@ class Generator: def transaction_sql(self, *_): return "BEGIN" - def commit_sql(self, *_): - return "COMMIT" + def commit_sql(self, expression): + chain = expression.args.get("chain") + if chain is not None: + chain = " AND CHAIN" if chain else " AND NO CHAIN" + + return f"COMMIT{chain or ''}" def rollback_sql(self, expression): savepoint = expression.args.get("savepoint") @@ -1334,15 +1349,15 @@ class Generator: result_sqls = [] for i, e in enumerate(expressions): sql = self.sql(e, comment=False) - comment = self.maybe_comment("", e, single_line=True) + comments = self.maybe_comment("", e) if self.pretty: if self._leading_comma: - result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}") + result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}") else: - result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}") + result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}") else: - result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}") + result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}") result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) return self.indent(result_sqls, skip_first=False) if indent else result_sqls @@ -1354,7 +1369,10 @@ class Generator: return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}" def naked_property(self, expression): - return f"{expression.name} {self.sql(expression, 'value')}" + property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__) + if not property_name: + self.unsupported(f"Unsupported property {expression.__class__.__name__}") + return f"{property_name} {self.sql(expression, 'this')}" def set_operation(self, expression, op): this = self.sql(expression, "this") diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 8704e90..39e252c 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -68,6 +68,9 @@ def eliminate_subqueries(expression): for cte_scope in root.cte_scopes: # Append all the new CTEs from this existing CTE for scope in cte_scope.traverse(): + if scope is cte_scope: + # Don't try to eliminate this CTE itself + continue new_cte = _eliminate(scope, existing_ctes, taken) if new_cte: new_ctes.append(new_cte) @@ -97,6 +100,9 @@ def _eliminate(scope, existing_ctes, taken): if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF): return _eliminate_derived_table(scope, existing_ctes, taken) + if scope.is_cte: + return _eliminate_cte(scope, existing_ctes, taken) + def _eliminate_union(scope, existing_ctes, taken): duplicate_cte_alias = existing_ctes.get(scope.expression) @@ -127,26 +133,61 @@ def _eliminate_union(scope, existing_ctes, taken): def _eliminate_derived_table(scope, existing_ctes, taken): + parent = scope.expression.parent + name, cte = _new_cte(scope, existing_ctes, taken) + + table = exp.alias_(exp.table_(name), alias=parent.alias or name) + parent.replace(table) + + return cte + + +def _eliminate_cte(scope, existing_ctes, taken): + parent 
= scope.expression.parent
+    name, cte = _new_cte(scope, existing_ctes, taken)
+
+    with_ = parent.parent
+    parent.pop()
+    if not with_.expressions:
+        with_.pop()
+
+    # Rename references to this CTE
+    for child_scope in scope.parent.traverse():
+        for table, source in child_scope.selected_sources.values():
+            if source is scope:
+                new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name)
+                table.replace(new_table)
+
+    return cte
+
+
+def _new_cte(scope, existing_ctes, taken):
+    """
+    Returns:
+        tuple of (name, cte)
+        where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
+        If this CTE duplicates an existing CTE, `cte` will be None.
+    """
     duplicate_cte_alias = existing_ctes.get(scope.expression)
     parent = scope.expression.parent
-    name = alias = parent.alias
+    name = parent.alias
 
-    if not alias:
-        name = alias = find_new_name(taken=taken, base="cte")
+    if not name:
+        name = find_new_name(taken=taken, base="cte")
 
     if duplicate_cte_alias:
         name = duplicate_cte_alias
-    elif taken.get(alias):
-        name = find_new_name(taken=taken, base=alias)
+    elif taken.get(name):
+        name = find_new_name(taken=taken, base=name)
 
     taken[name] = scope
 
-    table = exp.alias_(exp.table_(name), alias=alias)
-    parent.replace(table)
-
     if not duplicate_cte_alias:
         existing_ctes[scope.expression] = name
-
-    return exp.CTE(
-        this=scope.expression,
-        alias=exp.TableAlias(this=exp.to_identifier(name)),
-    )
+        cte = exp.CTE(
+            this=scope.expression,
+            alias=exp.TableAlias(this=exp.to_identifier(name)),
+        )
+    else:
+        cte = None
+    return name, cte
diff --git a/sqlglot/optimizer/lower_identities.py b/sqlglot/optimizer/lower_identities.py
new file mode 100644
index 0000000..1cc76cf
--- /dev/null
+++ b/sqlglot/optimizer/lower_identities.py
@@ -0,0 +1,92 @@
+from sqlglot import exp
+from sqlglot.helper import ensure_collection
+
+
+def lower_identities(expression):
+    """
+    Convert all unquoted identifiers to lower case.
+
+    Assuming the schema is all lower case, this essentially makes identifiers case-insensitive.
+
+    Example:
+        >>> import sqlglot
+        >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
+        >>> lower_identities(expression).sql()
+        'SELECT bar.a AS A FROM "Foo".bar'
+
+    Args:
+        expression (sqlglot.Expression): expression to transform
+    Returns:
+        sqlglot.Expression: expression with unquoted identifiers converted to lower case
+    """
+    # We need to leave the output aliases unchanged, so the selects need special handling
+    _lower_selects(expression)
+
+    # These clauses can reference output aliases and also need special handling
+    _lower_order(expression)
+    _lower_having(expression)
+
+    # We've already handled these args, so don't traverse into them
+    traversed = {"expressions", "order", "having"}
+
+    if isinstance(expression, exp.Subquery):
+        # Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1
+        lower_identities(expression.this)
+        traversed |= {"this"}
+
+    if isinstance(expression, exp.Union):
+        # Union, e.g.
SELECT A AS A FROM X UNION SELECT A AS A FROM X + lower_identities(expression.left) + lower_identities(expression.right) + traversed |= {"this", "expression"} + + for k, v in expression.args.items(): + if k in traversed: + continue + + for child in ensure_collection(v): + if isinstance(child, exp.Expression): + child.transform(_lower, copy=False) + + return expression + + +def _lower_selects(expression): + for e in expression.expressions: + # Leave output aliases as-is + e.unalias().transform(_lower, copy=False) + + +def _lower_order(expression): + order = expression.args.get("order") + + if not order: + return + + output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)} + + for ordered in order.expressions: + # Don't lower references to output aliases + if not ( + isinstance(ordered.this, exp.Column) + and not ordered.this.table + and ordered.this.name in output_aliases + ): + ordered.transform(_lower, copy=False) + + +def _lower_having(expression): + having = expression.args.get("having") + + if not having: + return + + # Don't lower references to output aliases + for agg in having.find_all(exp.AggFunc): + agg.transform(_lower, copy=False) + + +def _lower(node): + if isinstance(node, exp.Identifier) and not node.quoted: + node.set("this", node.this.lower()) + return node diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index d0e38cd..6819717 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -6,6 +6,7 @@ from sqlglot.optimizer.eliminate_joins import eliminate_joins from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects from sqlglot.optimizer.isolate_table_selects import isolate_table_selects +from sqlglot.optimizer.lower_identities import lower_identities from sqlglot.optimizer.merge_subqueries import merge_subqueries from sqlglot.optimizer.normalize import normalize from sqlglot.optimizer.optimize_joins import optimize_joins @@ -17,6 +18,7 @@ from sqlglot.optimizer.quote_identities import quote_identities from sqlglot.optimizer.unnest_subqueries import unnest_subqueries RULES = ( + lower_identities, qualify_tables, isolate_table_selects, qualify_columns, diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index dbd680b..2046917 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -1,16 +1,15 @@ import itertools from sqlglot import exp -from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.scope import ScopeType, traverse_scope def unnest_subqueries(expression): """ Rewrite sqlglot AST to convert some predicates with subqueries into joins. - Convert the subquery into a group by so it is not a many to many left join. - Unnesting can only occur if the subquery does not have LIMIT or OFFSET. - Unnesting non correlated subqueries only happens on IN statements or = ANY statements. + Convert scalar subqueries into cross joins. + Convert correlated or vectorized subqueries into a group by so it is not a many to many left join. 
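+    Scalar subqueries referenced under a HAVING clause are wrapped in MAX() so the cross join stays one-to-one.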
Example: >>> import sqlglot @@ -29,21 +28,43 @@ def unnest_subqueries(expression): for scope in traverse_scope(expression): select = scope.expression parent = select.parent_select + if not parent: + continue if scope.external_columns: decorrelate(select, parent, scope.external_columns, sequence) - else: + elif scope.scope_type == ScopeType.SUBQUERY: unnest(select, parent, sequence) return expression def unnest(select, parent_select, sequence): - predicate = select.find_ancestor(exp.In, exp.Any) + if len(select.selects) > 1: + return + + predicate = select.find_ancestor(exp.Condition) + alias = _alias(sequence) if not predicate or parent_select is not predicate.parent_select: return - if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset): + # this subquery returns a scalar and can just be converted to a cross join + if not isinstance(predicate, (exp.In, exp.Any)): + having = predicate.find_ancestor(exp.Having) + column = exp.column(select.selects[0].alias_or_name, alias) + if having and having.parent_select is parent_select: + column = exp.Max(this=column) + _replace(select.parent, column) + + parent_select.join( + select, + join_type="CROSS", + join_alias=alias, + copy=False, + ) + return + + if select.find(exp.Limit, exp.Offset): return if isinstance(predicate, exp.Any): @@ -54,7 +75,6 @@ def unnest(select, parent_select, sequence): column = _other_operand(predicate) value = select.selects[0] - alias = _alias(sequence) on = exp.condition(f'{column} = "{alias}"."{value.alias}"') _replace(predicate, f"NOT {on.right} IS NULL") diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 5b93510..bdf0d2d 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -4,7 +4,7 @@ import logging import typing as t from sqlglot import exp -from sqlglot.errors import ErrorLevel, ParseError, concat_errors +from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors from sqlglot.helper import apply_index_offset, ensure_collection, seq_get from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import in_trie, new_trie @@ -104,6 +104,7 @@ class Parser(metaclass=_Parser): TokenType.BINARY, TokenType.VARBINARY, TokenType.JSON, + TokenType.JSONB, TokenType.INTERVAL, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, @@ -115,6 +116,7 @@ class Parser(metaclass=_Parser): TokenType.GEOGRAPHY, TokenType.GEOMETRY, TokenType.HLLSKETCH, + TokenType.HSTORE, TokenType.SUPER, TokenType.SERIAL, TokenType.SMALLSERIAL, @@ -153,6 +155,7 @@ class Parser(metaclass=_Parser): TokenType.COLLATE, TokenType.COMMAND, TokenType.COMMIT, + TokenType.COMPOUND, TokenType.CONSTRAINT, TokenType.CURRENT_TIME, TokenType.DEFAULT, @@ -194,6 +197,7 @@ class Parser(metaclass=_Parser): TokenType.RANGE, TokenType.REFERENCES, TokenType.RETURNS, + TokenType.ROW, TokenType.ROWS, TokenType.SCHEMA, TokenType.SCHEMA_COMMENT, @@ -213,6 +217,7 @@ class Parser(metaclass=_Parser): TokenType.TRUE, TokenType.UNBOUNDED, TokenType.UNIQUE, + TokenType.UNLOGGED, TokenType.UNPIVOT, TokenType.PROPERTIES, TokenType.PROCEDURE, @@ -400,9 +405,17 @@ class Parser(metaclass=_Parser): TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()), TokenType.BEGIN: lambda self: self._parse_transaction(), TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), + TokenType.END: lambda self: self._parse_commit_or_rollback(), TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), } + UNARY_PARSERS = { + TokenType.PLUS: lambda self: self._parse_unary(), # Unary + is handled as a no-op + 
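+        # NOT binds looser than comparison operators, so it takes a full equality expression as its operand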
TokenType.NOT: lambda self: self.expression(exp.Not, this=self._parse_equality()), + TokenType.TILDA: lambda self: self.expression(exp.BitwiseNot, this=self._parse_unary()), + TokenType.DASH: lambda self: self.expression(exp.Neg, this=self._parse_unary()), + } + PRIMARY_PARSERS = { TokenType.STRING: lambda self, token: self.expression( exp.Literal, this=token.text, is_string=True @@ -446,19 +459,20 @@ class Parser(metaclass=_Parser): } PROPERTY_PARSERS = { - TokenType.AUTO_INCREMENT: lambda self: self._parse_auto_increment(), - TokenType.CHARACTER_SET: lambda self: self._parse_character_set(), - TokenType.LOCATION: lambda self: self.expression( - exp.LocationProperty, - this=exp.Literal.string("LOCATION"), - value=self._parse_string(), + TokenType.AUTO_INCREMENT: lambda self: self._parse_property_assignment( + exp.AutoIncrementProperty ), + TokenType.CHARACTER_SET: lambda self: self._parse_character_set(), + TokenType.LOCATION: lambda self: self._parse_property_assignment(exp.LocationProperty), TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(), - TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(), - TokenType.STORED: lambda self: self._parse_stored(), + TokenType.SCHEMA_COMMENT: lambda self: self._parse_property_assignment( + exp.SchemaCommentProperty + ), + TokenType.STORED: lambda self: self._parse_property_assignment(exp.FileFormatProperty), TokenType.DISTKEY: lambda self: self._parse_distkey(), - TokenType.DISTSTYLE: lambda self: self._parse_diststyle(), + TokenType.DISTSTYLE: lambda self: self._parse_property_assignment(exp.DistStyleProperty), TokenType.SORTKEY: lambda self: self._parse_sortkey(), + TokenType.LIKE: lambda self: self._parse_create_like(), TokenType.RETURNS: lambda self: self._parse_returns(), TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty), TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), @@ -468,7 +482,7 @@ class Parser(metaclass=_Parser): ), TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty), TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty), - TokenType.EXECUTE: lambda self: self._parse_execute_as(), + TokenType.EXECUTE: lambda self: self._parse_property_assignment(exp.ExecuteAsProperty), TokenType.DETERMINISTIC: lambda self: self.expression( exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE") ), @@ -489,6 +503,7 @@ class Parser(metaclass=_Parser): ), TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(), TokenType.UNIQUE: lambda self: self._parse_unique(), + TokenType.LIKE: lambda self: self._parse_create_like(), } NO_PAREN_FUNCTION_PARSERS = { @@ -505,6 +520,7 @@ class Parser(metaclass=_Parser): "TRIM": lambda self: self._parse_trim(), "CAST": lambda self: self._parse_cast(self.STRICT_CAST), "TRY_CAST": lambda self: self._parse_cast(False), + "STRING_AGG": lambda self: self._parse_string_agg(), } QUERY_MODIFIER_PARSERS = { @@ -556,7 +572,7 @@ class Parser(metaclass=_Parser): "_curr", "_next", "_prev", - "_prev_comment", + "_prev_comments", "_show_trie", "_set_trie", ) @@ -589,7 +605,7 @@ class Parser(metaclass=_Parser): self._curr = None self._next = None self._prev = None - self._prev_comment = None + self._prev_comments = None def parse(self, raw_tokens, sql=None): """ @@ -608,6 +624,7 @@ class Parser(metaclass=_Parser): ) def parse_into(self, expression_types, raw_tokens, sql=None): + errors = [] for expression_type in ensure_collection(expression_types): 
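+            # Descriptive note (added commentary): each candidate expression type is
+            # tried in turn; a failed attempt has its ParseError tagged with
+            # "into_expression" and collected, so the ParseError raised after the loop
+            # carries the merged, structured errors for every attempted type.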
parser = self.EXPRESSION_PARSERS.get(expression_type) if not parser: @@ -615,8 +632,12 @@ class Parser(metaclass=_Parser): try: return self._parse(parser, raw_tokens, sql) except ParseError as e: - error = e - raise ParseError(f"Failed to parse into {expression_types}") from error + e.errors[0]["into_expression"] = expression_type + errors.append(e) + raise ParseError( + f"Failed to parse into {expression_types}", + errors=merge_errors(errors), + ) from errors[-1] def _parse(self, parse_method, raw_tokens, sql=None): self.reset() @@ -650,7 +671,10 @@ class Parser(metaclass=_Parser): for error in self.errors: logger.error(str(error)) elif self.error_level == ErrorLevel.RAISE and self.errors: - raise ParseError(concat_errors(self.errors, self.max_errors)) + raise ParseError( + concat_messages(self.errors, self.max_errors), + errors=merge_errors(self.errors), + ) def raise_error(self, message, token=None): token = token or self._curr or self._prev or Token.string("") @@ -659,19 +683,27 @@ class Parser(metaclass=_Parser): start_context = self.sql[max(start - self.error_message_context, 0) : start] highlight = self.sql[start:end] end_context = self.sql[end : end + self.error_message_context] - error = ParseError( + error = ParseError.new( f"{message}. Line {token.line}, Col: {token.col}.\n" - f" {start_context}\033[4m{highlight}\033[0m{end_context}" + f" {start_context}\033[4m{highlight}\033[0m{end_context}", + description=message, + line=token.line, + col=token.col, + start_context=start_context, + highlight=highlight, + end_context=end_context, ) if self.error_level == ErrorLevel.IMMEDIATE: raise error self.errors.append(error) - def expression(self, exp_class, **kwargs): + def expression(self, exp_class, comments=None, **kwargs): instance = exp_class(**kwargs) - if self._prev_comment: - instance.comment = self._prev_comment - self._prev_comment = None + if self._prev_comments: + instance.comments = self._prev_comments + self._prev_comments = None + if comments: + instance.comments = comments self.validate_expression(instance) return instance @@ -714,10 +746,10 @@ class Parser(metaclass=_Parser): self._next = seq_get(self._tokens, self._index + 1) if self._index > 0: self._prev = self._tokens[self._index - 1] - self._prev_comment = self._prev.comment + self._prev_comments = self._prev.comments else: self._prev = None - self._prev_comment = None + self._prev_comments = None def _retreat(self, index): self._advance(index - self._index) @@ -768,7 +800,7 @@ class Parser(metaclass=_Parser): ) def _parse_create(self): - replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE) + replace = self._match_pair(TokenType.OR, TokenType.REPLACE) temporary = self._match(TokenType.TEMPORARY) transient = self._match(TokenType.TRANSIENT) unique = self._match(TokenType.UNIQUE) @@ -822,97 +854,57 @@ class Parser(metaclass=_Parser): def _parse_property(self): if self._match_set(self.PROPERTY_PARSERS): return self.PROPERTY_PARSERS[self._prev.token_type](self) + if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET): return self._parse_character_set(True) + if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY): + return self._parse_sortkey(compound=True) + if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False): - key = self._parse_var().this + key = self._parse_var() self._match(TokenType.EQ) - - return self.expression( - exp.AnonymousProperty, - this=exp.Literal.string(key), - value=self._parse_column(), - ) + return self.expression(exp.Property, this=key, 
value=self._parse_column()) return None def _parse_property_assignment(self, exp_class): - prop = self._prev.text self._match(TokenType.EQ) - return self.expression(exp_class, this=prop, value=self._parse_var_or_string()) + self._match(TokenType.ALIAS) + return self.expression(exp_class, this=self._parse_var_or_string() or self._parse_number()) def _parse_partitioned_by(self): self._match(TokenType.EQ) return self.expression( exp.PartitionedByProperty, - this=exp.Literal.string("PARTITIONED_BY"), - value=self._parse_schema() or self._parse_bracket(self._parse_field()), - ) - - def _parse_stored(self): - self._match(TokenType.ALIAS) - self._match(TokenType.EQ) - return self.expression( - exp.FileFormatProperty, - this=exp.Literal.string("FORMAT"), - value=exp.Literal.string(self._parse_var_or_string().name), + this=self._parse_schema() or self._parse_bracket(self._parse_field()), ) def _parse_distkey(self): - self._match_l_paren() - this = exp.Literal.string("DISTKEY") - value = exp.Literal.string(self._parse_var().name) - self._match_r_paren() - return self.expression( - exp.DistKeyProperty, - this=this, - value=value, - ) + return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_var)) - def _parse_sortkey(self): - self._match_l_paren() - this = exp.Literal.string("SORTKEY") - value = exp.Literal.string(self._parse_var().name) - self._match_r_paren() - return self.expression( - exp.SortKeyProperty, - this=this, - value=value, - ) - - def _parse_diststyle(self): - this = exp.Literal.string("DISTSTYLE") - value = exp.Literal.string(self._parse_var().name) - return self.expression( - exp.DistStyleProperty, - this=this, - value=value, - ) - - def _parse_auto_increment(self): - self._match(TokenType.EQ) - return self.expression( - exp.AutoIncrementProperty, - this=exp.Literal.string("AUTO_INCREMENT"), - value=self._parse_number(), - ) + def _parse_create_like(self): + table = self._parse_table(schema=True) + options = [] + while self._match_texts(("INCLUDING", "EXCLUDING")): + options.append( + self.expression( + exp.Property, + this=self._prev.text.upper(), + value=exp.Var(this=self._parse_id_var().this.upper()), + ) + ) + return self.expression(exp.LikeProperty, this=table, expressions=options) - def _parse_schema_comment(self): - self._match(TokenType.EQ) + def _parse_sortkey(self, compound=False): return self.expression( - exp.SchemaCommentProperty, - this=exp.Literal.string("COMMENT"), - value=self._parse_string(), + exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_var), compound=compound ) def _parse_character_set(self, default=False): self._match(TokenType.EQ) return self.expression( - exp.CharacterSetProperty, - this=exp.Literal.string("CHARACTER_SET"), - value=self._parse_var_or_string(), - default=default, + exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default ) def _parse_returns(self): @@ -931,20 +923,7 @@ class Parser(metaclass=_Parser): else: value = self._parse_types() - return self.expression( - exp.ReturnsProperty, - this=exp.Literal.string("RETURNS"), - value=value, - is_table=is_table, - ) - - def _parse_execute_as(self): - self._match(TokenType.ALIAS) - return self.expression( - exp.ExecuteAsProperty, - this=exp.Literal.string("EXECUTE AS"), - value=self._parse_var(), - ) + return self.expression(exp.ReturnsProperty, this=value, is_table=is_table) def _parse_properties(self): properties = [] @@ -956,7 +935,7 @@ class Parser(metaclass=_Parser): properties.extend( self._parse_wrapped_csv( lambda: self.expression( - 
exp.AnonymousProperty, + exp.Property, this=self._parse_string(), value=self._match(TokenType.EQ) and self._parse_string(), ) @@ -1076,7 +1055,12 @@ class Parser(metaclass=_Parser): options = [] if self._match(TokenType.OPTIONS): - options = self._parse_wrapped_csv(self._parse_string, sep=TokenType.EQ) + self._match_l_paren() + k = self._parse_string() + self._match(TokenType.EQ) + v = self._parse_string() + options = [k, v] + self._match_r_paren() self._match(TokenType.ALIAS) return self.expression( @@ -1116,7 +1100,7 @@ class Parser(metaclass=_Parser): self.raise_error(f"{this.key} does not support CTE") this = cte elif self._match(TokenType.SELECT): - comment = self._prev_comment + comments = self._prev_comments hint = self._parse_hint() all_ = self._match(TokenType.ALL) @@ -1141,10 +1125,16 @@ class Parser(metaclass=_Parser): expressions=expressions, limit=limit, ) - this.comment = comment + this.comments = comments + + into = self._parse_into() + if into: + this.set("into", into) + from_ = self._parse_from() if from_: this.set("from", from_) + self._parse_query_modifiers(this) elif (table or nested) and self._match(TokenType.L_PAREN): this = self._parse_table() if table else self._parse_select(nested=True) @@ -1248,11 +1238,24 @@ class Parser(metaclass=_Parser): return self.expression(exp.Hint, expressions=hints) return None + def _parse_into(self): + if not self._match(TokenType.INTO): + return None + + temp = self._match(TokenType.TEMPORARY) + unlogged = self._match(TokenType.UNLOGGED) + self._match(TokenType.TABLE) + + return self.expression( + exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged + ) + def _parse_from(self): if not self._match(TokenType.FROM): return None - - return self.expression(exp.From, expressions=self._parse_csv(self._parse_table)) + return self.expression( + exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table) + ) def _parse_lateral(self): outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY) @@ -1515,7 +1518,9 @@ class Parser(metaclass=_Parser): def _parse_where(self, skip_where_token=False): if not skip_where_token and not self._match(TokenType.WHERE): return None - return self.expression(exp.Where, this=self._parse_conjunction()) + return self.expression( + exp.Where, comments=self._prev_comments, this=self._parse_conjunction() + ) def _parse_group(self, skip_group_by_token=False): if not skip_group_by_token and not self._match(TokenType.GROUP_BY): @@ -1737,12 +1742,8 @@ class Parser(metaclass=_Parser): return self._parse_tokens(self._parse_unary, self.FACTOR) def _parse_unary(self): - if self._match(TokenType.NOT): - return self.expression(exp.Not, this=self._parse_equality()) - if self._match(TokenType.TILDA): - return self.expression(exp.BitwiseNot, this=self._parse_unary()) - if self._match(TokenType.DASH): - return self.expression(exp.Neg, this=self._parse_unary()) + if self._match_set(self.UNARY_PARSERS): + return self.UNARY_PARSERS[self._prev.token_type](self) return self._parse_at_time_zone(self._parse_type()) def _parse_type(self): @@ -1775,17 +1776,6 @@ class Parser(metaclass=_Parser): expressions = None maybe_func = False - if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): - return exp.DataType( - this=exp.DataType.Type.ARRAY, - expressions=[exp.DataType.build(type_token.value)], - nested=True, - ) - - if self._match(TokenType.L_BRACKET): - self._retreat(index) - return None - if self._match(TokenType.L_PAREN): if is_struct: expressions = 
self._parse_csv(self._parse_struct_kwargs) @@ -1801,6 +1791,17 @@ class Parser(metaclass=_Parser): self._match_r_paren() maybe_func = True + if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): + return exp.DataType( + this=exp.DataType.Type.ARRAY, + expressions=[exp.DataType.build(type_token.value, expressions=expressions)], + nested=True, + ) + + if self._match(TokenType.L_BRACKET): + self._retreat(index) + return None + if nested and self._match(TokenType.LT): if is_struct: expressions = self._parse_csv(self._parse_struct_kwargs) @@ -1904,7 +1905,7 @@ class Parser(metaclass=_Parser): return exp.Literal.number(f"0.{self._prev.text}") if self._match(TokenType.L_PAREN): - comment = self._prev_comment + comments = self._prev_comments query = self._parse_select() if query: @@ -1924,8 +1925,8 @@ class Parser(metaclass=_Parser): this = self.expression(exp.Tuple, expressions=expressions) else: this = self.expression(exp.Paren, this=this) - if comment: - this.comment = comment + if comments: + this.comments = comments return this return None @@ -2098,7 +2099,10 @@ class Parser(metaclass=_Parser): elif self._match(TokenType.SCHEMA_COMMENT): kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string()) elif self._match(TokenType.PRIMARY_KEY): - kind = exp.PrimaryKeyColumnConstraint() + desc = None + if self._match(TokenType.ASC) or self._match(TokenType.DESC): + desc = self._prev.token_type == TokenType.DESC + kind = exp.PrimaryKeyColumnConstraint(desc=desc) elif self._match(TokenType.UNIQUE): kind = exp.UniqueColumnConstraint() elif self._match(TokenType.GENERATED): @@ -2189,7 +2193,7 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.R_BRACKET): self.raise_error("Expected ]") - this.comment = self._prev_comment + this.comments = self._prev_comments return self._parse_bracket(this) def _parse_case(self): @@ -2256,6 +2260,33 @@ class Parser(metaclass=_Parser): return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) + def _parse_string_agg(self): + if self._match(TokenType.DISTINCT): + args = self._parse_csv(self._parse_conjunction) + expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)]) + else: + args = self._parse_csv(self._parse_conjunction) + expression = seq_get(args, 0) + + index = self._index + if not self._match(TokenType.R_PAREN): + # postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]]) + order = self._parse_order(this=expression) + return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1)) + + # Checks if we can parse an order clause: WITHIN GROUP (ORDER BY [ASC | DESC]). + # This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that + # the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them. 
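+        # Illustrative forms (mirroring the dialect tests) that all normalize to the
+        # same exp.GroupConcat node:
+        #   postgres: STRING_AGG(DISTINCT x, ',' ORDER BY y DESC)
+        #   tsql:     STRING_AGG(x, ',') WITHIN GROUP (ORDER BY y DESC)
+        #   mysql:    GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')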
+ if not self._match(TokenType.WITHIN_GROUP): + self._retreat(index) + this = exp.GroupConcat.from_arg_list(args) + self.validate_expression(this, args) + return this + + self._match_l_paren() # The corresponding match_r_paren will be called in parse_function (caller) + order = self._parse_order(this=expression) + return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1)) + def _parse_convert(self, strict): this = self._parse_column() if self._match(TokenType.USING): @@ -2511,8 +2542,8 @@ class Parser(metaclass=_Parser): items = [parse_result] if parse_result is not None else [] while self._match(sep): - if parse_result and self._prev_comment is not None: - parse_result.comment = self._prev_comment + if parse_result and self._prev_comments: + parse_result.comments = self._prev_comments parse_result = parse_method() if parse_result is not None: @@ -2525,7 +2556,10 @@ class Parser(metaclass=_Parser): while self._match_set(expressions): this = self.expression( - expressions[self._prev.token_type], this=this, expression=parse_method() + expressions[self._prev.token_type], + this=this, + comments=self._prev_comments, + expression=parse_method(), ) return this @@ -2566,6 +2600,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Transaction, this=this, modes=modes) def _parse_commit_or_rollback(self): + chain = None savepoint = None is_rollback = self._prev.token_type == TokenType.ROLLBACK @@ -2575,9 +2610,13 @@ class Parser(metaclass=_Parser): self._match_text_seq("SAVEPOINT") savepoint = self._parse_id_var() + if self._match(TokenType.AND): + chain = not self._match_text_seq("NO") + self._match_text_seq("CHAIN") + if is_rollback: return self.expression(exp.Rollback, savepoint=savepoint) - return self.expression(exp.Commit) + return self.expression(exp.Commit, chain=chain) def _parse_show(self): parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) @@ -2651,14 +2690,14 @@ class Parser(metaclass=_Parser): def _match_l_paren(self, expression=None): if not self._match(TokenType.L_PAREN): self.raise_error("Expecting (") - if expression and self._prev_comment: - expression.comment = self._prev_comment + if expression and self._prev_comments: + expression.comments = self._prev_comments def _match_r_paren(self, expression=None): if not self._match(TokenType.R_PAREN): self.raise_error("Expecting )") - if expression and self._prev_comment: - expression.comment = self._prev_comment + if expression and self._prev_comments: + expression.comments = self._prev_comments def _match_texts(self, texts): if self._curr and self._curr.text.upper() in texts: diff --git a/sqlglot/planner.py b/sqlglot/planner.py index 51db2d4..4967231 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -130,18 +130,20 @@ class Step: aggregations = [] sequence = itertools.count() - for e in expression.expressions: - aggregation = e.find(exp.AggFunc) - - if aggregation: - projections.append(exp.column(e.alias_or_name, step.name, quoted=True)) - aggregations.append(e) - for operand in aggregation.unnest_operands(): + def extract_agg_operands(expression): + for agg in expression.find_all(exp.AggFunc): + for operand in agg.unnest_operands(): if isinstance(operand, exp.Column): continue if operand not in operands: operands[operand] = f"_a_{next(sequence)}" operand.replace(exp.column(operands[operand], quoted=True)) + + for e in expression.expressions: + if e.find(exp.AggFunc): + projections.append(exp.column(e.alias_or_name, step.name, quoted=True)) + aggregations.append(e) + 
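+                # Descriptive note (added commentary): for a projection like
+                # SUM(a + 1), the non-column operand a + 1 is pulled out as an
+                # aliased operand (e.g. "_a_0") computed by the source step, so
+                # the aggregate itself only consumes plain columns.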
extract_agg_operands(e) else: projections.append(e) @@ -156,6 +158,13 @@ class Step: aggregate = Aggregate() aggregate.source = step.name aggregate.name = step.name + + having = expression.args.get("having") + + if having: + extract_agg_operands(having) + aggregate.condition = having.this + aggregate.operands = tuple( alias(operand, alias_) for operand, alias_ in operands.items() ) @@ -172,11 +181,6 @@ class Step: aggregate.add_dependency(step) step = aggregate - having = expression.args.get("having") - - if having: - step.condition = having.this - order = expression.args.get("order") if order: @@ -188,6 +192,17 @@ class Step: step.projections = projections + if isinstance(expression, exp.Select) and expression.args.get("distinct"): + distinct = Aggregate() + distinct.source = step.name + distinct.name = step.name + distinct.group = { + e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name) + for e in projections or expression.expressions + } + distinct.add_dependency(step) + step = distinct + limit = expression.args.get("limit") if limit: @@ -231,6 +246,9 @@ class Step: if self.condition: lines.append(f"{nested}Condition: {self.condition.sql()}") + if self.limit is not math.inf: + lines.append(f"{nested}Limit: {self.limit}") + if self.dependencies: lines.append(f"{nested}Dependencies:") for dependency in self.dependencies: @@ -258,12 +276,7 @@ class Scan(Step): cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None ) -> Step: table = expression - alias_ = expression.alias - - if not alias_: - raise UnsupportedError( - "Tables/Subqueries must be aliased. Run it through the optimizer" - ) + alias_ = expression.alias_or_name if isinstance(expression, exp.Subquery): table = expression.this @@ -338,6 +351,9 @@ class Aggregate(Step): lines.append(f"{indent}Group:") for expression in self.group.values(): lines.append(f"{indent} - {expression.sql()}") + if self.condition: + lines.append(f"{indent}Having:") + lines.append(f"{indent} - {self.condition.sql()}") if self.operands: lines.append(f"{indent}Operands:") for expression in self.operands: diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index ec8cd91..8a7a38e 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -81,6 +81,7 @@ class TokenType(AutoName): BINARY = auto() VARBINARY = auto() JSON = auto() + JSONB = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() TIMESTAMPLTZ = auto() @@ -91,6 +92,7 @@ class TokenType(AutoName): NULLABLE = auto() GEOMETRY = auto() HLLSKETCH = auto() + HSTORE = auto() SUPER = auto() SERIAL = auto() SMALLSERIAL = auto() @@ -113,6 +115,7 @@ class TokenType(AutoName): APPLY = auto() ARRAY = auto() ASC = auto() + ASOF = auto() AT_TIME_ZONE = auto() AUTO_INCREMENT = auto() BEGIN = auto() @@ -130,6 +133,7 @@ class TokenType(AutoName): COMMAND = auto() COMMENT = auto() COMMIT = auto() + COMPOUND = auto() CONSTRAINT = auto() CREATE = auto() CROSS = auto() @@ -271,6 +275,7 @@ class TokenType(AutoName): UNBOUNDED = auto() UNCACHE = auto() UNION = auto() + UNLOGGED = auto() UNNEST = auto() UNPIVOT = auto() UPDATE = auto() @@ -291,7 +296,7 @@ class TokenType(AutoName): class Token: - __slots__ = ("token_type", "text", "line", "col", "comment") + __slots__ = ("token_type", "text", "line", "col", "comments") @classmethod def number(cls, number: int) -> Token: @@ -319,13 +324,13 @@ class Token: text: str, line: int = 1, col: int = 1, - comment: t.Optional[str] = None, + comments: t.List[str] = [], ) -> None: self.token_type = token_type self.text = text self.line = line self.col = max(col - 
len(text), 1) - self.comment = comment + self.comments = comments def __repr__(self) -> str: attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__) @@ -452,6 +457,7 @@ class Tokenizer(metaclass=_Tokenizer): "COLLATE": TokenType.COLLATE, "COMMENT": TokenType.SCHEMA_COMMENT, "COMMIT": TokenType.COMMIT, + "COMPOUND": TokenType.COMPOUND, "CONSTRAINT": TokenType.CONSTRAINT, "CREATE": TokenType.CREATE, "CROSS": TokenType.CROSS, @@ -582,8 +588,9 @@ class Tokenizer(metaclass=_Tokenizer): "TRAILING": TokenType.TRAILING, "UNBOUNDED": TokenType.UNBOUNDED, "UNION": TokenType.UNION, - "UNPIVOT": TokenType.UNPIVOT, + "UNLOGGED": TokenType.UNLOGGED, "UNNEST": TokenType.UNNEST, + "UNPIVOT": TokenType.UNPIVOT, "UPDATE": TokenType.UPDATE, "USE": TokenType.USE, "USING": TokenType.USING, @@ -686,12 +693,12 @@ class Tokenizer(metaclass=_Tokenizer): "_current", "_line", "_col", - "_comment", + "_comments", "_char", "_end", "_peek", "_prev_token_line", - "_prev_token_comment", + "_prev_token_comments", "_prev_token_type", "_replace_backslash", ) @@ -708,13 +715,13 @@ class Tokenizer(metaclass=_Tokenizer): self._current = 0 self._line = 1 self._col = 1 - self._comment = None + self._comments: t.List[str] = [] self._char = None self._end = None self._peek = None self._prev_token_line = -1 - self._prev_token_comment = None + self._prev_token_comments: t.List[str] = [] self._prev_token_type = None def tokenize(self, sql: str) -> t.List[Token]: @@ -767,7 +774,7 @@ class Tokenizer(metaclass=_Tokenizer): def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None: self._prev_token_line = self._line - self._prev_token_comment = self._comment + self._prev_token_comments = self._comments self._prev_token_type = token_type # type: ignore self.tokens.append( Token( @@ -775,10 +782,10 @@ class Tokenizer(metaclass=_Tokenizer): self._text if text is None else text, self._line, self._col, - self._comment, + self._comments, ) ) - self._comment = None + self._comments = [] if token_type in self.COMMANDS and ( len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON @@ -857,22 +864,18 @@ class Tokenizer(metaclass=_Tokenizer): while not self._end and self._chars(comment_end_size) != comment_end: self._advance() - self._comment = self._text[comment_start_size : -comment_end_size + 1] # type: ignore + self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore self._advance(comment_end_size - 1) else: while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore self._advance() - self._comment = self._text[comment_start_size:] # type: ignore - - # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both - # types of comment can be attached to a token, the trailing one is discarded in favour of the leading one. + self._comments.append(self._text[comment_start_size:]) # type: ignore + # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. + # Multiple consecutive comments are preserved by appending them to the current comments list. 
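+        # Illustrative example (drawn from the updated tokenizer tests): in
+        # "foo /*comment 1*/ /*comment 2*/" both comments trail "foo" on the same
+        # line, so tokens[0].comments becomes ["comment 1", "comment 2"].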
if comment_start_line == self._prev_token_line: - if self._prev_token_comment is None: - self.tokens[-1].comment = self._comment - self._prev_token_comment = self._comment - - self._comment = None + self.tokens[-1].comments.extend(self._comments) + self._comments = [] return True diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 412b881..99949a1 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -2,6 +2,8 @@ from __future__ import annotations import typing as t +from sqlglot.helper import find_new_name + if t.TYPE_CHECKING: from sqlglot.generator import Generator @@ -43,6 +45,43 @@ def unalias_group(expression: exp.Expression) -> exp.Expression: return expression +def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: + """ + Convert SELECT DISTINCT ON statements to a subquery with a window function. + + This is useful for dialects that don't support SELECT DISTINCT ON but support window functions. + + Args: + expression: the expression that will be transformed. + + Returns: + The transformed expression. + """ + if ( + isinstance(expression, exp.Select) + and expression.args.get("distinct") + and expression.args["distinct"].args.get("on") + and isinstance(expression.args["distinct"].args["on"], exp.Tuple) + ): + distinct_cols = [e.copy() for e in expression.args["distinct"].args["on"].expressions] + outer_selects = [e.copy() for e in expression.expressions] + nested = expression.copy() + nested.args["distinct"].pop() + row_number = find_new_name(expression.named_selects, "_row_number") + window = exp.Window( + this=exp.RowNumber(), + partition_by=distinct_cols, + ) + order = nested.args.get("order") + if order: + window.set("order", order.copy()) + order.pop() + window = exp.alias_(window, row_number) + nested.select(window, copy=False) + return exp.select(*outer_selects).from_(nested.subquery()).where(f'"{row_number}" = 1') + return expression + + def preprocess( transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], to_sql: t.Callable[[Generator, exp.Expression], str], @@ -81,3 +120,4 @@ def delegate(attr: str) -> t.Callable: UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))} +ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))} diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index 8e5e5cd..99b140d 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -1276,7 +1276,7 @@ class TestFunctions(unittest.TestCase): col = SF.concat(SF.col("cola"), SF.col("colb")) self.assertEqual("CONCAT(cola, colb)", col.sql()) col_single = SF.concat("cola") - self.assertEqual("CONCAT(cola)", col_single.sql()) + self.assertEqual("cola", col_single.sql()) def test_array_position(self): col_str = SF.array_position("cola", SF.col("colb")) diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index efb41bb..c95c967 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -10,6 +10,10 @@ class TestClickhouse(Validator): self.validate_identity("SELECT * FROM x AS y FINAL") self.validate_identity("'a' IN mapKeys(map('a', 1, 'b', 2))") self.validate_identity("CAST((1, 2) AS Tuple(a Int8, b Int16))") + self.validate_identity("SELECT * FROM foo LEFT ANY JOIN bla") + self.validate_identity("SELECT * FROM foo LEFT ASOF JOIN bla") + self.validate_identity("SELECT * FROM foo ASOF JOIN bla") + self.validate_identity("SELECT * FROM foo ANY JOIN 
bla") self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 1b2f9c1..6033570 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -997,6 +997,13 @@ class TestDialect(Validator): "spark": "CONCAT_WS('-', x)", }, ) + self.validate_all( + "CONCAT(a)", + write={ + "mysql": "a", + "tsql": "a", + }, + ) self.validate_all( "IF(x > 1, 1, 0)", write={ @@ -1263,8 +1270,8 @@ class TestDialect(Validator): self.validate_all( """/* comment1 */ SELECT - x, -- comment2 - y -- comment3""", + x, /* comment2 */ + y /* comment3 */""", read={ "mysql": """SELECT # comment1 x, # comment2 diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 625156b..99b0493 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -89,6 +89,8 @@ class TestDuckDB(Validator): "presto": "CAST(COL AS ARRAY(BIGINT))", "hive": "CAST(COL AS ARRAY)", "spark": "CAST(COL AS ARRAY)", + "postgres": "CAST(COL AS BIGINT[])", + "snowflake": "CAST(COL AS ARRAY)", }, ) @@ -104,6 +106,10 @@ class TestDuckDB(Validator): "spark": "ARRAY(0, 1, 2)", }, ) + self.validate_all( + "SELECT ARRAY_LENGTH([0], 1) AS x", + write={"duckdb": "SELECT ARRAY_LENGTH(LIST_VALUE(0), 1) AS x"}, + ) self.validate_all( "REGEXP_MATCHES(x, y)", write={ diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 69c7630..22d7bce 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -139,7 +139,7 @@ class TestHive(Validator): "CREATE TABLE test STORED AS parquet TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1", write={ "duckdb": "CREATE TABLE test AS SELECT 1", - "presto": "CREATE TABLE test WITH (FORMAT='parquet', x='1', Z='2') AS SELECT 1", + "presto": "CREATE TABLE test WITH (FORMAT='PARQUET', x='1', Z='2') AS SELECT 1", "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1", "spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1", }, @@ -459,6 +459,7 @@ class TestHive(Validator): "hive": "MAP(a, b, c, d)", "presto": "MAP(ARRAY[a, c], ARRAY[b, d])", "spark": "MAP(a, b, c, d)", + "snowflake": "OBJECT_CONSTRUCT(a, b, c, d)", }, write={ "": "MAP(ARRAY(a, c), ARRAY(b, d))", @@ -467,6 +468,7 @@ class TestHive(Validator): "presto": "MAP(ARRAY[a, c], ARRAY[b, d])", "hive": "MAP(a, b, c, d)", "spark": "MAP(a, b, c, d)", + "snowflake": "OBJECT_CONSTRUCT(a, b, c, d)", }, ) self.validate_all( @@ -476,6 +478,7 @@ class TestHive(Validator): "presto": "MAP(ARRAY[a], ARRAY[b])", "hive": "MAP(a, b)", "spark": "MAP(a, b)", + "snowflake": "OBJECT_CONSTRUCT(a, b)", }, ) self.validate_all( diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index af98249..5064dbe 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -23,6 +23,8 @@ class TestMySQL(Validator): self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')") self.validate_identity("@@GLOBAL.max_connections") + self.validate_identity("CREATE TABLE A LIKE B") + # SET Commands self.validate_identity("SET @var_name = expr") self.validate_identity("SET @name = 43") @@ -177,14 +179,27 @@ class TestMySQL(Validator): "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)", write={ "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')", - "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)", + "sqlite": "GROUP_CONCAT(DISTINCT x)", + "tsql": "STRING_AGG(x, ',') WITHIN 
+                "postgres": "STRING_AGG(DISTINCT x, ',' ORDER BY y DESC NULLS LAST)",
+            },
+        )
+        self.validate_all(
+            "GROUP_CONCAT(x ORDER BY y SEPARATOR z)",
+            write={
+                "mysql": "GROUP_CONCAT(x ORDER BY y SEPARATOR z)",
+                "sqlite": "GROUP_CONCAT(x, z)",
+                "tsql": "STRING_AGG(x, z) WITHIN GROUP (ORDER BY y)",
+                "postgres": "STRING_AGG(x, z ORDER BY y NULLS FIRST)",
             },
         )
         self.validate_all(
             "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
             write={
                 "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
-                "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC, '')",
+                "sqlite": "GROUP_CONCAT(DISTINCT x, '')",
+                "tsql": "STRING_AGG(x, '') WITHIN GROUP (ORDER BY y DESC)",
+                "postgres": "STRING_AGG(DISTINCT x, '' ORDER BY y DESC NULLS LAST)",
             },
         )
         self.validate_identity(
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index 8294eea..cd6117c 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -6,6 +6,9 @@ class TestPostgres(Validator):
     dialect = "postgres"
 
     def test_ddl(self):
+        self.validate_identity("CREATE TABLE test (foo HSTORE)")
+        self.validate_identity("CREATE TABLE test (foo JSONB)")
+        self.validate_identity("CREATE TABLE test (foo VARCHAR(64)[])")
         self.validate_all(
             "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
             write={
@@ -60,6 +63,12 @@ class TestPostgres(Validator):
         )
 
     def test_postgres(self):
+        self.validate_identity("SELECT ARRAY[1, 2, 3]")
+        self.validate_identity("SELECT ARRAY_LENGTH(ARRAY[1, 2, 3], 1)")
+        self.validate_identity("STRING_AGG(x, y)")
+        self.validate_identity("STRING_AGG(x, ',' ORDER BY y)")
+        self.validate_identity("STRING_AGG(x, ',' ORDER BY y DESC)")
+        self.validate_identity("STRING_AGG(DISTINCT x, ',' ORDER BY y DESC)")
         self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
         self.validate_identity(
             "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END"
@@ -86,6 +95,14 @@ class TestPostgres(Validator):
         self.validate_identity("SELECT e'\\xDEADBEEF'")
         self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
 
+        self.validate_all(
+            "END WORK AND NO CHAIN",
+            write={"postgres": "COMMIT AND NO CHAIN"},
+        )
+        self.validate_all(
+            "END AND CHAIN",
+            write={"postgres": "COMMIT AND CHAIN"},
+        )
         self.validate_all(
             "CREATE TABLE x (a UUID, b BYTEA)",
             write={
@@ -95,6 +112,10 @@ class TestPostgres(Validator):
                 "spark": "CREATE TABLE x (a UUID, b BINARY)",
             },
         )
+
+        self.validate_identity(
+            "CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)"
+        )
         self.validate_all(
             "SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS 1 PRECEDING)",
             write={
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index 8179cf7..70e1059 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -13,6 +13,7 @@ class TestPresto(Validator):
                 "duckdb": "CAST(a AS INT[])",
                 "presto": "CAST(a AS ARRAY(INTEGER))",
                 "spark": "CAST(a AS ARRAY<INT>)",
+                "snowflake": "CAST(a AS ARRAY)",
             },
         )
         self.validate_all(
@@ -31,6 +32,7 @@ class TestPresto(Validator):
                 "duckdb": "CAST(LIST_VALUE(1, 2) AS BIGINT[])",
                 "presto": "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))",
                 "spark": "CAST(ARRAY(1, 2) AS ARRAY<LONG>)",
+                "snowflake": "CAST([1, 2] AS ARRAY)",
             },
         )
         self.validate_all(
@@ -41,6 +43,7 @@ class TestPresto(Validator):
                 "presto": "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INTEGER, INTEGER))",
                 "hive": "CAST(MAP(1, 1) AS MAP<INT, INT>)",
                 "spark": "CAST(MAP_FROM_ARRAYS(ARRAY(1), ARRAY(1)) AS MAP<INT, INT>)",
+                "snowflake": "CAST(OBJECT_CONSTRUCT(1, 1) AS OBJECT)",
"CAST(OBJECT_CONSTRUCT(1, 1) AS OBJECT)", }, ) self.validate_all( @@ -51,6 +54,7 @@ class TestPresto(Validator): "presto": "CAST(MAP(ARRAY['a', 'b', 'c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INTEGER)))", "hive": "CAST(MAP('a', ARRAY(1), 'b', ARRAY(2), 'c', ARRAY(3)) AS MAP>)", "spark": "CAST(MAP_FROM_ARRAYS(ARRAY('a', 'b', 'c'), ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))) AS MAP>)", + "snowflake": "CAST(OBJECT_CONSTRUCT('a', [1], 'b', [2], 'c', [3]) AS OBJECT)", }, ) self.validate_all( @@ -393,6 +397,7 @@ class TestPresto(Validator): write={ "hive": UnsupportedError, "spark": "MAP_FROM_ARRAYS(a, b)", + "snowflake": UnsupportedError, }, ) self.validate_all( @@ -401,6 +406,7 @@ class TestPresto(Validator): "hive": "MAP(a, c, b, d)", "presto": "MAP(ARRAY[a, b], ARRAY[c, d])", "spark": "MAP_FROM_ARRAYS(ARRAY(a, b), ARRAY(c, d))", + "snowflake": "OBJECT_CONSTRUCT(a, c, b, d)", }, ) self.validate_all( @@ -409,6 +415,7 @@ class TestPresto(Validator): "hive": "MAP('a', 'b')", "presto": "MAP(ARRAY['a'], ARRAY['b'])", "spark": "MAP_FROM_ARRAYS(ARRAY('a'), ARRAY('b'))", + "snowflake": "OBJECT_CONSTRUCT('a', 'b')", }, ) self.validate_all( diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 5309a34..1943ee3 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -50,6 +50,12 @@ class TestRedshift(Validator): "redshift": 'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\\\_%\' LIMIT 5' }, ) + self.validate_all( + "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC", + write={ + "redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1', + }, + ) def test_identity(self): self.validate_identity("CAST('bla' AS SUPER)") @@ -64,3 +70,13 @@ class TestRedshift(Validator): self.validate_identity( "SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'" ) + self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE AUTO") + self.validate_identity( + "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid)" + ) + self.validate_identity( + "COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'" + ) + self.validate_identity( + "UNLOAD ('select * from venue') TO 's3://mybucket/unload/' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'" + ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 0e69f4e..baca269 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -172,13 +172,28 @@ class TestSnowflake(Validator): self.validate_all( "trim(date_column, 'UTC')", write={ + "bigquery": "TRIM(date_column, 'UTC')", "snowflake": "TRIM(date_column, 'UTC')", "postgres": "TRIM('UTC' FROM date_column)", }, ) self.validate_all( "trim(date_column)", - write={"snowflake": "TRIM(date_column)"}, + write={ + "snowflake": "TRIM(date_column)", + "bigquery": "TRIM(date_column)", + }, + ) + self.validate_all( + "DECODE(x, a, b, c, d)", + read={ + "": "MATCHES(x, a, b, c, d)", + }, + write={ + "": "MATCHES(x, a, b, c, d)", + "oracle": "DECODE(x, a, b, c, d)", + "snowflake": "DECODE(x, a, b, c, d)", + }, ) def test_null_treatment(self): @@ -370,7 +385,8 @@ class TestSnowflake(Validator): ) self.validate_all( - r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""} + r"""SELECT * FROM TABLE(?)""", + 
write={"snowflake": r"""SELECT * FROM TABLE(?)"""}, ) self.validate_all( diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 4470722..3a9f918 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -32,13 +32,14 @@ class TestSpark(Validator): "presto": "CREATE TABLE db.example_table (col_a ARRAY(INTEGER), col_b ARRAY(ARRAY(INTEGER)))", "hive": "CREATE TABLE db.example_table (col_a ARRAY, col_b ARRAY>)", "spark": "CREATE TABLE db.example_table (col_a ARRAY, col_b ARRAY>)", + "snowflake": "CREATE TABLE db.example_table (col_a ARRAY, col_b ARRAY)", }, ) self.validate_all( "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", write={ "duckdb": "CREATE TABLE x", - "presto": "CREATE TABLE x WITH (TABLE_FORMAT = 'ICEBERG', PARTITIONED_BY=ARRAY['MONTHS'])", + "presto": "CREATE TABLE x WITH (TABLE_FORMAT='ICEBERG', PARTITIONED_BY=ARRAY['MONTHS'])", "hive": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", "spark": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", }, @@ -94,6 +95,13 @@ TBLPROPERTIES ( pretty=True, ) + self.validate_all( + "CACHE TABLE testCache OPTIONS ('storageLevel' 'DISK_ONLY') SELECT * FROM testData", + write={ + "spark": "CACHE TABLE testCache OPTIONS('storageLevel' = 'DISK_ONLY') AS SELECT * FROM testData" + }, + ) + def test_to_date(self): self.validate_all( "TO_DATE(x, 'yyyy-MM-dd')", @@ -271,6 +279,7 @@ TBLPROPERTIES ( "presto": "MAP(ARRAY[1], c)", "hive": "MAP(ARRAY(1), c)", "spark": "MAP_FROM_ARRAYS(ARRAY(1), c)", + "snowflake": "OBJECT_CONSTRUCT([1], c)", }, ) self.validate_all( diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index 3cc974c..e54a4bc 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -5,6 +5,10 @@ class TestSQLite(Validator): dialect = "sqlite" def test_ddl(self): + self.validate_all( + "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)", + write={"sqlite": "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)"}, + ) self.validate_all( """ CREATE TABLE "Track" diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index a60f48d..afdd48a 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -17,7 +17,6 @@ class TestTSQL(Validator): "spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo", }, ) - self.validate_all( "CONVERT(INT, CONVERT(NUMERIC, '444.75'))", write={ @@ -25,6 +24,33 @@ class TestTSQL(Validator): "tsql": "CAST(CAST('444.75' AS NUMERIC) AS INTEGER)", }, ) + self.validate_all( + "STRING_AGG(x, y) WITHIN GROUP (ORDER BY z DESC)", + write={ + "tsql": "STRING_AGG(x, y) WITHIN GROUP (ORDER BY z DESC)", + "mysql": "GROUP_CONCAT(x ORDER BY z DESC SEPARATOR y)", + "sqlite": "GROUP_CONCAT(x, y)", + "postgres": "STRING_AGG(x, y ORDER BY z DESC NULLS LAST)", + }, + ) + self.validate_all( + "STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z ASC)", + write={ + "tsql": "STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z)", + "mysql": "GROUP_CONCAT(x ORDER BY z SEPARATOR '|')", + "sqlite": "GROUP_CONCAT(x, '|')", + "postgres": "STRING_AGG(x, '|' ORDER BY z NULLS FIRST)", + }, + ) + self.validate_all( + "STRING_AGG(x, '|')", + write={ + "tsql": "STRING_AGG(x, '|')", + "mysql": "GROUP_CONCAT(x SEPARATOR '|')", + "sqlite": "GROUP_CONCAT(x, '|')", + "postgres": "STRING_AGG(x, '|')", + }, + ) def test_types(self): self.validate_identity("CAST(x AS XML)") diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 75bd25d..06ab96d 100644 --- 
a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -34,6 +34,7 @@ x >> 1 x >> 1 | 1 & 1 ^ 1 x || y 1 - -1 +- -5 dec.x + y a.filter a.b.c @@ -438,6 +439,7 @@ SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(a, b) SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b) SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score) SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score) +CREATE TABLE foo (id INT PRIMARY KEY ASC) CREATE TABLE a.b AS SELECT 1 CREATE TABLE a.b AS SELECT a FROM a.c CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d @@ -579,6 +581,7 @@ SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo) SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl) SELECT CAST(x AS INT) /* comment */ FROM foo SELECT a /* x */, b /* x */ +SELECT a /* x */ /* y */ /* z */, b /* k */ /* m */ SELECT * FROM foo /* x */, bla /* x */ SELECT 1 /* comment */ + 1 SELECT 1 /* c1 */ + 2 /* c2 */ @@ -588,3 +591,7 @@ SELECT x FROM a.b.c /* x */, e.f.g /* x */ SELECT FOO(x /* c */) /* FOO */, b /* b */ SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */ SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b' +SELECT x AS INTO FROM bla +SELECT * INTO newevent FROM event +SELECT * INTO TEMPORARY newevent FROM event +SELECT * INTO UNLOGGED newevent FROM event diff --git a/tests/fixtures/optimizer/eliminate_subqueries.sql b/tests/fixtures/optimizer/eliminate_subqueries.sql index f395c0a..c566657 100644 --- a/tests/fixtures/optimizer/eliminate_subqueries.sql +++ b/tests/fixtures/optimizer/eliminate_subqueries.sql @@ -77,3 +77,15 @@ WITH x_2 AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT x.id FROM x -- Existing duplicate CTE WITH y AS (SELECT a FROM x) SELECT a FROM (SELECT a FROM x) AS y JOIN y AS z; WITH y AS (SELECT a FROM x) SELECT a FROM y AS y JOIN y AS z; + +-- Nested CTE +WITH cte1 AS (SELECT a FROM x) SELECT a FROM (WITH cte2 AS (SELECT a FROM cte1) SELECT a FROM cte2); +WITH cte1 AS (SELECT a FROM x), cte2 AS (SELECT a FROM cte1), cte AS (SELECT a FROM cte2 AS cte2) SELECT a FROM cte AS cte; + +-- Nested CTE inside CTE +WITH cte1 AS (WITH cte2 AS (SELECT a FROM x) SELECT t.a FROM cte2 AS t) SELECT a FROM cte1; +WITH cte2 AS (SELECT a FROM x), cte1 AS (SELECT t.a FROM cte2 AS t) SELECT a FROM cte1; + +-- Duplicate CTE nested in CTE +WITH cte1 AS (SELECT a FROM x), cte2 AS (WITH cte3 AS (SELECT a FROM x) SELECT a FROM cte3) SELECT a FROM cte2; +WITH cte1 AS (SELECT a FROM x), cte2 AS (SELECT a FROM cte1 AS cte3) SELECT a FROM cte2; diff --git a/tests/fixtures/optimizer/lower_identities.sql b/tests/fixtures/optimizer/lower_identities.sql new file mode 100644 index 0000000..cea346f --- /dev/null +++ b/tests/fixtures/optimizer/lower_identities.sql @@ -0,0 +1,41 @@ +SELECT a FROM x; +SELECT a FROM x; + +SELECT "A" FROM "X"; +SELECT "A" FROM "X"; + +SELECT a AS A FROM x; +SELECT a AS A FROM x; + +SELECT * FROM x; +SELECT * FROM x; + +SELECT A FROM x; +SELECT a FROM x; + +SELECT a FROM X; +SELECT a FROM x; + +SELECT A AS A FROM (SELECT a AS A FROM x); +SELECT a AS A FROM (SELECT a AS a FROM x); + +SELECT a AS B FROM x ORDER BY B; +SELECT a AS B FROM x ORDER BY B; + +SELECT A FROM x ORDER BY A; +SELECT a FROM x ORDER BY a; + +SELECT A AS B FROM X GROUP BY A HAVING SUM(B) > 0; +SELECT a AS B FROM x GROUP BY a HAVING SUM(b) > 0; + +SELECT A AS B, SUM(B) AS C FROM X GROUP BY A HAVING C > 0; +SELECT a AS B, SUM(b) AS C FROM x GROUP BY a HAVING C > 
0; + +SELECT A FROM X UNION SELECT A FROM X; +SELECT a FROM x UNION SELECT a FROM x; + +SELECT A AS A FROM X UNION SELECT A AS A FROM X; +SELECT a AS A FROM x UNION SELECT a AS A FROM x; + +(SELECT A AS A FROM X); +(SELECT a AS A FROM x); diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index a1e531b..a692c7d 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -276,3 +276,18 @@ SELECT /*+ COALESCE(3), FROM `x` AS `x` JOIN `y` AS `y` ON `x`.`b` = `y`.`b`; + +WITH cte1 AS ( + WITH cte2 AS ( + SELECT a, b FROM x + ) + SELECT a1 + FROM ( + WITH cte3 AS (SELECT 1) + SELECT a AS a1, b AS b1 FROM cte2 + ) +) +SELECT a1 FROM cte1; +SELECT + "x"."a" AS "a1" +FROM "x" AS "x"; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 7207ba2..d9c7779 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -274,6 +274,15 @@ TRUE; -(-1); 1; +- -+1; +1; + ++-1; +-1; + +++1; +1; + 0.06 - 0.01; 0.05; diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql index 8138b11..4893743 100644 --- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql +++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql @@ -666,11 +666,20 @@ WITH "supplier_2" AS ( FROM "nation" AS "nation" WHERE "nation"."n_name" = 'GERMANY' +), "_u_0" AS ( + SELECT + SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0" + FROM "partsupp" AS "partsupp" + JOIN "supplier_2" AS "supplier" + ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" + JOIN "nation_2" AS "nation" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" ) SELECT "partsupp"."ps_partkey" AS "ps_partkey", SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value" FROM "partsupp" AS "partsupp" +CROSS JOIN "_u_0" AS "_u_0" JOIN "supplier_2" AS "supplier" ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" JOIN "nation_2" AS "nation" @@ -678,15 +687,7 @@ JOIN "nation_2" AS "nation" GROUP BY "partsupp"."ps_partkey" HAVING - SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > ( - SELECT - SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0" - FROM "partsupp" AS "partsupp" - JOIN "supplier_2" AS "supplier" - ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" - JOIN "nation_2" AS "nation" - ON "supplier"."s_nationkey" = "nation"."n_nationkey" - ) + SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > MAX("_u_0"."_col_0") ORDER BY "value" DESC; @@ -880,6 +881,10 @@ WITH "revenue" AS ( AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1996-01-01' AS DATE) GROUP BY "lineitem"."l_suppkey" +), "_u_0" AS ( + SELECT + MAX("revenue"."total_revenue") AS "_col_0" + FROM "revenue" ) SELECT "supplier"."s_suppkey" AS "s_suppkey", @@ -889,12 +894,9 @@ SELECT "revenue"."total_revenue" AS "total_revenue" FROM "supplier" AS "supplier" JOIN "revenue" - ON "revenue"."total_revenue" = ( - SELECT - MAX("revenue"."total_revenue") AS "_col_0" - FROM "revenue" - ) - AND "supplier"."s_suppkey" = "revenue"."supplier_no" + ON "supplier"."s_suppkey" = "revenue"."supplier_no" +JOIN "_u_0" AS "_u_0" + ON "revenue"."total_revenue" = "_u_0"."_col_0" ORDER BY "s_suppkey"; @@ -1395,7 +1397,14 @@ order by cntrycode; WITH "_u_0" AS ( SELECT - "orders"."o_custkey" AS "_u_1" + AVG("customer"."c_acctbal") AS "_col_0" + FROM "customer" AS "customer" + WHERE + "customer"."c_acctbal" > 0.00 + AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', 
'29', '30', '18', '17') +), "_u_1" AS ( + SELECT + "orders"."o_custkey" AS "_u_2" FROM "orders" AS "orders" GROUP BY "orders"."o_custkey" @@ -1405,18 +1414,12 @@ SELECT COUNT(*) AS "numcust", SUM("customer"."c_acctbal") AS "totacctbal" FROM "customer" AS "customer" -LEFT JOIN "_u_0" AS "_u_0" - ON "_u_0"."_u_1" = "customer"."c_custkey" +JOIN "_u_0" AS "_u_0" + ON "customer"."c_acctbal" > "_u_0"."_col_0" +LEFT JOIN "_u_1" AS "_u_1" + ON "_u_1"."_u_2" = "customer"."c_custkey" WHERE - "_u_0"."_u_1" IS NULL - AND "customer"."c_acctbal" > ( - SELECT - AVG("customer"."c_acctbal") AS "_col_0" - FROM "customer" AS "customer" - WHERE - "customer"."c_acctbal" > 0.00 - AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17') - ) + "_u_1"."_u_2" IS NULL AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17') GROUP BY SUBSTRING("customer"."c_phone", 1, 2) diff --git a/tests/fixtures/optimizer/unnest_subqueries.sql b/tests/fixtures/optimizer/unnest_subqueries.sql index f53121a..dc373a0 100644 --- a/tests/fixtures/optimizer/unnest_subqueries.sql +++ b/tests/fixtures/optimizer/unnest_subqueries.sql @@ -1,10 +1,12 @@ +--SELECT x.a > (SELECT SUM(y.a) AS b FROM y) FROM x; -------------------------------------- -- Unnest Subqueries -------------------------------------- SELECT * FROM x AS x WHERE - x.a IN (SELECT y.a AS a FROM y) + x.a = (SELECT SUM(y.a) AS a FROM y) + AND x.a IN (SELECT y.a AS a FROM y) AND x.a IN (SELECT y.b AS b FROM y) AND x.a = ANY (SELECT y.a AS a FROM y) AND x.a = (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a) @@ -24,52 +26,57 @@ WHERE SELECT * FROM x AS x +CROSS JOIN ( + SELECT + SUM(y.a) AS a + FROM y +) AS "_u_0" LEFT JOIN ( SELECT y.a AS a FROM y GROUP BY y.a -) AS "_u_0" - ON x.a = "_u_0"."a" +) AS "_u_1" + ON x.a = "_u_1"."a" LEFT JOIN ( SELECT y.b AS b FROM y GROUP BY y.b -) AS "_u_1" - ON x.a = "_u_1"."b" +) AS "_u_2" + ON x.a = "_u_2"."b" LEFT JOIN ( SELECT y.a AS a FROM y GROUP BY y.a -) AS "_u_2" - ON x.a = "_u_2"."a" +) AS "_u_3" + ON x.a = "_u_3"."a" LEFT JOIN ( SELECT SUM(y.b) AS b, - y.a AS _u_4 + y.a AS _u_5 FROM y WHERE TRUE GROUP BY y.a -) AS "_u_3" - ON x.a = "_u_3"."_u_4" +) AS "_u_4" + ON x.a = "_u_4"."_u_5" LEFT JOIN ( SELECT SUM(y.b) AS b, - y.a AS _u_6 + y.a AS _u_7 FROM y WHERE TRUE GROUP BY y.a -) AS "_u_5" - ON x.a = "_u_5"."_u_6" +) AS "_u_6" + ON x.a = "_u_6"."_u_7" LEFT JOIN ( SELECT y.a AS a @@ -78,8 +85,8 @@ LEFT JOIN ( TRUE GROUP BY y.a -) AS "_u_7" - ON "_u_7".a = x.a +) AS "_u_8" + ON "_u_8".a = x.a LEFT JOIN ( SELECT y.a AS a @@ -88,31 +95,31 @@ LEFT JOIN ( TRUE GROUP BY y.a -) AS "_u_8" - ON "_u_8".a = x.a +) AS "_u_9" + ON "_u_9".a = x.a LEFT JOIN ( SELECT ARRAY_AGG(y.a) AS a, - y.b AS _u_10 + y.b AS _u_11 FROM y WHERE TRUE GROUP BY y.b -) AS "_u_9" - ON "_u_9"."_u_10" = x.a +) AS "_u_10" + ON "_u_10"."_u_11" = x.a LEFT JOIN ( SELECT SUM(y.a) AS a, - y.a AS _u_12, - ARRAY_AGG(y.b) AS _u_13 + y.a AS _u_13, + ARRAY_AGG(y.b) AS _u_14 FROM y WHERE TRUE AND TRUE AND TRUE GROUP BY y.a -) AS "_u_11" - ON "_u_11"."_u_12" = x.a AND "_u_11"."_u_12" = x.b +) AS "_u_12" + ON "_u_12"."_u_13" = x.a AND "_u_12"."_u_13" = x.b LEFT JOIN ( SELECT y.a AS a @@ -121,37 +128,38 @@ LEFT JOIN ( TRUE GROUP BY y.a -) AS "_u_14" - ON x.a = "_u_14".a +) AS "_u_15" + ON x.a = "_u_15".a WHERE - NOT "_u_0"."a" IS NULL - AND NOT "_u_1"."b" IS NULL - AND NOT "_u_2"."a" IS NULL + x.a = "_u_0".a + AND NOT "_u_1"."a" IS NULL + AND NOT "_u_2"."b" IS NULL + AND NOT "_u_3"."a" IS NULL AND ( - x.a = "_u_3".b AND NOT 
"_u_3"."_u_4" IS NULL + x.a = "_u_4".b AND NOT "_u_4"."_u_5" IS NULL ) AND ( - x.a > "_u_5".b AND NOT "_u_5"."_u_6" IS NULL + x.a > "_u_6".b AND NOT "_u_6"."_u_7" IS NULL ) AND ( - None = "_u_7".a AND NOT "_u_7".a IS NULL + None = "_u_8".a AND NOT "_u_8".a IS NULL ) AND NOT ( - x.a = "_u_8".a AND NOT "_u_8".a IS NULL + x.a = "_u_9".a AND NOT "_u_9".a IS NULL ) AND ( - ARRAY_ANY("_u_9".a, _x -> _x = x.a) AND NOT "_u_9"."_u_10" IS NULL + ARRAY_ANY("_u_10".a, _x -> _x = x.a) AND NOT "_u_10"."_u_11" IS NULL ) AND ( ( ( - x.a < "_u_11".a AND NOT "_u_11"."_u_12" IS NULL - ) AND NOT "_u_11"."_u_12" IS NULL + x.a < "_u_12".a AND NOT "_u_12"."_u_13" IS NULL + ) AND NOT "_u_12"."_u_13" IS NULL ) - AND ARRAY_ANY("_u_11"."_u_13", "_x" -> "_x" <> x.d) + AND ARRAY_ANY("_u_12"."_u_14", "_x" -> "_x" <> x.d) ) AND ( - NOT "_u_14".a IS NULL AND NOT "_u_14".a IS NULL + NOT "_u_15".a IS NULL AND NOT "_u_15".a IS NULL ) AND x.a IN ( SELECT diff --git a/tests/test_executor.py b/tests/test_executor.py index 2c4d7cd..9d452e4 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -68,13 +68,13 @@ class TestExecutor(unittest.TestCase): def test_execute_tpch(self): def to_csv(expression): - if isinstance(expression, exp.Table): + if isinstance(expression, exp.Table) and expression.name not in ("revenue"): return parse_one( f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}" ) return expression - for i, (sql, _) in enumerate(self.sqls[0:7]): + for i, (sql, _) in enumerate(self.sqls[0:16]): with self.subTest(f"tpch-h {i + 1}"): a = self.cached_execute(sql) sql = parse_one(sql).transform(to_csv).sql(pretty=True) @@ -165,6 +165,39 @@ class TestExecutor(unittest.TestCase): ["a"], [("a",)], ), + ( + "SELECT DISTINCT a FROM (SELECT 1 AS a UNION ALL SELECT 1 AS a)", + ["a"], + [(1,)], + ), + ( + "SELECT DISTINCT a, SUM(b) AS b " + "FROM (SELECT 'a' AS a, 1 AS b UNION ALL SELECT 'a' AS a, 2 AS b UNION ALL SELECT 'b' AS a, 1 AS b) " + "GROUP BY a " + "LIMIT 1", + ["a", "b"], + [("a", 3)], + ), + ( + "SELECT COUNT(1) AS a FROM (SELECT 1)", + ["a"], + [(1,)], + ), + ( + "SELECT COUNT(1) AS a FROM (SELECT 1) LIMIT 0", + ["a"], + [], + ), + ( + "SELECT a FROM x GROUP BY a LIMIT 0", + ["a"], + [], + ), + ( + "SELECT a FROM x LIMIT 0", + ["a"], + [], + ), ]: with self.subTest(sql): result = execute(sql, schema=schema, tables=tables) @@ -346,6 +379,28 @@ class TestExecutor(unittest.TestCase): ], ) + def test_execute_subqueries(self): + tables = { + "table": [ + {"a": 1, "b": 1}, + {"a": 2, "b": 2}, + ], + } + + self.assertEqual( + execute( + """ + SELECT * + FROM table + WHERE a = (SELECT MAX(a) FROM table) + """, + tables=tables, + ).rows, + [ + (2, 2), + ], + ) + def test_table_depth_mismatch(self): tables = {"table": []} schema = {"db": {"table": {"col": "VARCHAR"}}} @@ -401,6 +456,7 @@ class TestExecutor(unittest.TestCase): ("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]), ("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]), ("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]), + ("SELECT SUM(x) FROM (SELECT 1 AS x WHERE FALSE)", ["_col_0"], [(0,)]), ]: result = execute(sql) self.assertEqual(result.columns, tuple(cols)) @@ -462,7 +518,18 @@ class TestExecutor(unittest.TestCase): ("IF(false, 1, 0)", 0), ("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"), ("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)), + ("1 IN (1, 2, 3)", True), + ("1 IN (2, 3)", False), + ("NULL IS NULL", True), + ("NULL IS NOT NULL", False), + ("NULL = NULL", None), + ("NULL <> NULL", None), 
         ]:
             with self.subTest(sql):
                 result = execute(f"SELECT {sql}")
                 self.assertEqual(result.rows, [(expected,)])
+
+    def test_case_sensitivity(self):
+        result = execute("SELECT A AS A FROM X", tables={"x": [{"a": 1}]})
+        self.assertEqual(result.columns, ("A",))
+        self.assertEqual(result.rows, [(1,)])
diff --git a/tests/test_expressions.py b/tests/test_expressions.py
index c0927ad..0e13ade 100644
--- a/tests/test_expressions.py
+++ b/tests/test_expressions.py
@@ -525,24 +525,14 @@ class TestExpressions(unittest.TestCase):
             ),
             exp.Properties(
                 expressions=[
-                    exp.FileFormatProperty(
-                        this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet")
-                    ),
+                    exp.FileFormatProperty(this=exp.Literal.string("parquet")),
                     exp.PartitionedByProperty(
-                        this=exp.Literal.string("PARTITIONED_BY"),
-                        value=exp.Tuple(
-                            expressions=[exp.to_identifier("a"), exp.to_identifier("b")]
-                        ),
-                    ),
-                    exp.AnonymousProperty(
-                        this=exp.Literal.string("custom"), value=exp.Literal.number(1)
-                    ),
-                    exp.TableFormatProperty(
-                        this=exp.Literal.string("TABLE_FORMAT"),
-                        value=exp.to_identifier("test_format"),
+                        this=exp.Tuple(expressions=[exp.to_identifier("a"), exp.to_identifier("b")])
                     ),
-                    exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.null()),
-                    exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.true()),
+                    exp.Property(this=exp.Literal.string("custom"), value=exp.Literal.number(1)),
+                    exp.TableFormatProperty(this=exp.to_identifier("test_format")),
+                    exp.EngineProperty(this=exp.null()),
+                    exp.CollateProperty(this=exp.true()),
                 ]
             ),
         )
@@ -609,9 +599,9 @@ FROM foo""",
             """SELECT
   a,
   b AS B,
-  c, -- comment
-  d AS D, -- another comment
-  CAST(x AS INT) -- final comment
+  c, /* comment */
+  d AS D, /* another comment */
+  CAST(x AS INT) /* final comment */
 FROM foo""",
         )
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 6637a1d..ecf581d 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -85,9 +85,8 @@ class TestOptimizer(unittest.TestCase):
                 if leave_tables_isolated is not None:
                     func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated)
 
-                optimized = func(parse_one(sql, read=dialect), **func_kwargs)
-
                 with self.subTest(title):
+                    optimized = func(parse_one(sql, read=dialect), **func_kwargs)
                     self.assertEqual(
                         expected,
                         optimized.sql(pretty=pretty, dialect=dialect),
@@ -168,6 +167,9 @@ class TestOptimizer(unittest.TestCase):
     def test_quote_identities(self):
         self.check_file("quote_identities", optimizer.quote_identities.quote_identities)
 
+    def test_lower_identities(self):
+        self.check_file("lower_identities", optimizer.lower_identities.lower_identities)
+
     def test_pushdown_projection(self):
         def pushdown_projections(expression, **kwargs):
             expression = optimizer.qualify_tables.qualify_tables(expression)
diff --git a/tests/test_parser.py b/tests/test_parser.py
index c747ea3..fa7b589 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -15,6 +15,51 @@ class TestParser(unittest.TestCase):
         self.assertIsInstance(parse_one("int", into=exp.DataType), exp.DataType)
         self.assertIsInstance(parse_one("array", into=exp.DataType), exp.DataType)
 
+    def test_parse_into_error(self):
+        expected_message = "Failed to parse into [<class 'sqlglot.expressions.From'>]"
+        expected_errors = [
+            {
+                "description": "Invalid expression / Unexpected token",
+                "line": 1,
+                "col": 1,
+                "start_context": "",
+                "highlight": "SELECT",
+                "end_context": " 1;",
+                "into_expression": exp.From,
+            }
+        ]
+        with self.assertRaises(ParseError) as ctx:
+            parse_one("SELECT 1;", "sqlite", [exp.From])
+        self.assertEqual(str(ctx.exception), expected_message)
+        self.assertEqual(ctx.exception.errors, expected_errors)
+
+    def test_parse_into_errors(self):
+        expected_message = "Failed to parse into [<class 'sqlglot.expressions.From'>, <class 'sqlglot.expressions.Join'>]"
+        expected_errors = [
+            {
+                "description": "Invalid expression / Unexpected token",
+                "line": 1,
+                "col": 1,
+                "start_context": "",
+                "highlight": "SELECT",
+                "end_context": " 1;",
+                "into_expression": exp.From,
+            },
+            {
+                "description": "Invalid expression / Unexpected token",
+                "line": 1,
+                "col": 1,
+                "start_context": "",
+                "highlight": "SELECT",
+                "end_context": " 1;",
+                "into_expression": exp.Join,
+            },
+        ]
+        with self.assertRaises(ParseError) as ctx:
+            parse_one("SELECT 1;", "sqlite", [exp.From, exp.Join])
+        self.assertEqual(str(ctx.exception), expected_message)
+        self.assertEqual(ctx.exception.errors, expected_errors)
+
     def test_column(self):
         columns = parse_one("select a, ARRAY[1] b, case when 1 then 1 end").find_all(exp.Column)
         assert len(list(columns)) == 1
@@ -24,6 +69,9 @@ class TestParser(unittest.TestCase):
     def test_float(self):
         self.assertEqual(parse_one(".2"), parse_one("0.2"))
 
+    def test_unary_plus(self):
+        self.assertEqual(parse_one("+15"), exp.Literal.number(15))
+
     def test_table(self):
         tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
         self.assertEqual(tables, ["a", "b.c", "d"])
@@ -157,8 +205,9 @@ class TestParser(unittest.TestCase):
     def test_comments(self):
         expression = parse_one(
             """
-        --comment1
-        SELECT /* this won't be used */
+        --comment1.1
+        --comment1.2
+        SELECT /*comment1.3*/
             a, --comment2
             b as B, --comment3:testing
             "test--annotation",
@@ -169,13 +218,13 @@ class TestParser(unittest.TestCase):
         """
         )
 
-        self.assertEqual(expression.comment, "comment1")
-        self.assertEqual(expression.expressions[0].comment, "comment2")
-        self.assertEqual(expression.expressions[1].comment, "comment3:testing")
-        self.assertEqual(expression.expressions[2].comment, None)
-        self.assertEqual(expression.expressions[3].comment, "comment4 --foo")
-        self.assertEqual(expression.expressions[4].comment, "")
-        self.assertEqual(expression.expressions[5].comment, " space")
+        self.assertEqual(expression.comments, ["comment1.1", "comment1.2", "comment1.3"])
+        self.assertEqual(expression.expressions[0].comments, ["comment2"])
+        self.assertEqual(expression.expressions[1].comments, ["comment3:testing"])
+        self.assertEqual(expression.expressions[2].comments, None)
+        self.assertEqual(expression.expressions[3].comments, ["comment4 --foo"])
+        self.assertEqual(expression.expressions[4].comments, [""])
+        self.assertEqual(expression.expressions[5].comments, [" space"])
 
     def test_type_literals(self):
         self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
diff --git a/tests/test_tokens.py b/tests/test_tokens.py
index d4772ba..1d1b966 100644
--- a/tests/test_tokens.py
+++ b/tests/test_tokens.py
@@ -7,13 +7,13 @@ class TestTokens(unittest.TestCase):
     def test_comment_attachment(self):
         tokenizer = Tokenizer()
         sql_comment = [
-            ("/*comment*/ foo", "comment"),
-            ("/*comment*/ foo --test", "comment"),
-            ("--comment\nfoo --test", "comment"),
-            ("foo --comment", "comment"),
-            ("foo", None),
-            ("foo /*comment 1*/ /*comment 2*/", "comment 1"),
+            ("/*comment*/ foo", ["comment"]),
+            ("/*comment*/ foo --test", ["comment", "test"]),
+            ("--comment\nfoo --test", ["comment", "test"]),
+            ("foo --comment", ["comment"]),
+            ("foo", []),
+            ("foo /*comment 1*/ /*comment 2*/", ["comment 1", "comment 2"]),
         ]
 
         for sql, comment in sql_comment:
-            self.assertEqual(tokenizer.tokenize(sql)[0].comment, comment)
+            self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment)
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 1928d2c..0bcd2ca 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -1,7 +1,7 @@
 import unittest
 
 from sqlglot import parse_one
-from sqlglot.transforms import unalias_group
+from sqlglot.transforms import eliminate_distinct_on, unalias_group
 
 
 class TestTime(unittest.TestCase):
@@ -35,3 +35,30 @@ class TestTime(unittest.TestCase):
             "SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date",
             "SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date",
         )
+
+    def test_eliminate_distinct_on(self):
+        self.validate(
+            eliminate_distinct_on,
+            "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
+            'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
+        )
+        self.validate(
+            eliminate_distinct_on,
+            "SELECT DISTINCT ON (a) a, b FROM x",
+            'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a) AS "_row_number" FROM x) WHERE "_row_number" = 1',
+        )
+        self.validate(
+            eliminate_distinct_on,
+            "SELECT DISTINCT ON (a, b) a, b FROM x ORDER BY c DESC",
+            'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
+        )
+        self.validate(
+            eliminate_distinct_on,
+            "SELECT DISTINCT a, b FROM x ORDER BY c DESC",
+            "SELECT DISTINCT a, b FROM x ORDER BY c DESC",
+        )
+        self.validate(
+            eliminate_distinct_on,
+            "SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC",
+            'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS "_row_number_2" FROM x) WHERE "_row_number_2" = 1',
+        )
diff --git a/tests/test_transpile.py b/tests/test_transpile.py
index 1bd2527..7bf53e5 100644
--- a/tests/test_transpile.py
+++ b/tests/test_transpile.py
@@ -26,6 +26,7 @@ class TestTranspile(unittest.TestCase):
         )
         self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date")
         self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
+        self.assertEqual(transpile("SELECT 1 row")[0], "SELECT 1 AS row")
 
         for key in ("union", "filter", "over", "from", "join"):
             with self.subTest(f"alias {key}"):
@@ -38,6 +39,11 @@ class TestTranspile(unittest.TestCase):
     def test_asc(self):
         self.validate("SELECT x FROM y ORDER BY x ASC", "SELECT x FROM y ORDER BY x")
 
+    def test_unary(self):
+        self.validate("+++1", "1")
+        self.validate("+-1", "-1")
+        self.validate("+- - -1", "- - -1")
+
     def test_paren(self):
         with self.assertRaises(ParseError):
             transpile("1 + (2 + 3")
@@ -58,7 +64,7 @@ class TestTranspile(unittest.TestCase):
         )
         self.validate(
             "SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ",
-            "SELECT\n  FOO -- x\n  , BAR -- y\n  , BAZ",
+            "SELECT\n  FOO /* x */\n  , BAR /* y */\n  , BAZ",
             leading_comma=True,
             pretty=True,
         )
@@ -78,7 +84,8 @@ class TestTranspile(unittest.TestCase):
     def test_comments(self):
         self.validate("SELECT */*comment*/", "SELECT * /* comment */")
         self.validate(
-            "SELECT * FROM table /*comment 1*/ /*comment 2*/", "SELECT * FROM table /* comment 1 */"
+            "SELECT * FROM table /*comment 1*/ /*comment 2*/",
+            "SELECT * FROM table /* comment 1 */ /* comment 2 */",
         )
         self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
         self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
@@ -112,6 +119,53 @@ class TestTranspile(unittest.TestCase):
         )
         self.validate(
             """
+-- comment 1
+-- comment 2
+-- comment 3
+SELECT * FROM foo
+            """,
+            "/* comment 1 */ /* comment 2 */ /* comment 3 */ SELECT * FROM foo",
+        )
+        self.validate(
+            """
+-- comment 1
+-- comment 2
+-- comment 3
+SELECT * FROM foo""",
+            """/* comment 1 */
+/* comment 2 */
+/* comment 3 */
+SELECT
+  *
+FROM foo""",
+            pretty=True,
+        )
+        self.validate(
+            """
+SELECT * FROM tbl /*line1
+line2
+line3*/ /*another comment*/ where 1=1 -- comment at the end""",
+            """SELECT * FROM tbl /* line1
+line2
+line3 */ /* another comment */ WHERE 1 = 1 /* comment at the end */""",
+        )
+        self.validate(
+            """
+SELECT * FROM tbl /*line1
+line2
+line3*/ /*another comment*/ where 1=1 -- comment at the end""",
+            """SELECT
+  *
+FROM tbl /* line1
+line2
+line3 */
+/* another comment */
+WHERE
+  1 = 1 /* comment at the end */""",
+            pretty=True,
+        )
+        self.validate(
+            """
 /* multi
    line
    comment
 */
 SELECT
   tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
-  CAST(x AS INT), -- comment 3
-  y -- comment 4
+  CAST(x AS INT), /* comment 3 */
+  y /* comment 4 */
 FROM bar /* comment 5 */, tbl /* comment 6 */""",
             read="mysql",
             pretty=True,
         )
@@ -364,33 +418,79 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
 
     @mock.patch("sqlglot.parser.logger")
     def test_error_level(self, logger):
         invalid = "x + 1. ("
-        errors = [
+        expected_messages = [
             "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>. Line 1, Col: 8.\n  x + 1. \033[4m(\033[0m",
             "Expecting ). Line 1, Col: 8.\n  x + 1. \033[4m(\033[0m",
         ]
+        expected_errors = [
+            {
+                "description": "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>",
+                "line": 1,
+                "col": 8,
+                "start_context": "x + 1. ",
+                "highlight": "(",
+                "end_context": "",
+                "into_expression": None,
+            },
+            {
+                "description": "Expecting )",
+                "line": 1,
+                "col": 8,
+                "start_context": "x + 1. ",
+                "highlight": "(",
+                "end_context": "",
+                "into_expression": None,
+            },
+        ]
 
         transpile(invalid, error_level=ErrorLevel.WARN)
-        for error in errors:
+        for error in expected_messages:
             assert_logger_contains(error, logger)
 
         with self.assertRaises(ParseError) as ctx:
             transpile(invalid, error_level=ErrorLevel.IMMEDIATE)
-        self.assertEqual(str(ctx.exception), errors[0])
+        self.assertEqual(str(ctx.exception), expected_messages[0])
+        self.assertEqual(ctx.exception.errors[0], expected_errors[0])
 
         with self.assertRaises(ParseError) as ctx:
             transpile(invalid, error_level=ErrorLevel.RAISE)
-        self.assertEqual(str(ctx.exception), "\n\n".join(errors))
+        self.assertEqual(str(ctx.exception), "\n\n".join(expected_messages))
+        self.assertEqual(ctx.exception.errors, expected_errors)
 
         more_than_max_errors = "(((("
-        expected = (
+        expected_messages = (
             "Expecting ). Line 1, Col: 4.\n  (((\033[4m(\033[0m\n\n"
             "Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>. Line 1, Col: 4.\n  (((\033[4m(\033[0m\n\n"
             "Expecting ). Line 1, Col: 4.\n  (((\033[4m(\033[0m\n\n"
             "... and 2 more"
and 2 more" ) + expected_errors = [ + { + "description": "Expecting )", + "line": 1, + "col": 4, + "start_context": "(((", + "highlight": "(", + "end_context": "", + "into_expression": None, + }, + { + "description": "Required keyword: 'this' missing for ", + "line": 1, + "col": 4, + "start_context": "(((", + "highlight": "(", + "end_context": "", + "into_expression": None, + }, + ] + # Also expect three trailing structured errors that match the first + expected_errors += [expected_errors[0]] * 3 + with self.assertRaises(ParseError) as ctx: transpile(more_than_max_errors, error_level=ErrorLevel.RAISE) - self.assertEqual(str(ctx.exception), expected) + self.assertEqual(str(ctx.exception), expected_messages) + self.assertEqual(ctx.exception.errors, expected_errors) @mock.patch("sqlglot.generator.logger") def test_unsupported_level(self, logger): -- cgit v1.2.3