From 90150543f9314be683d22a16339effd774192f6d Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 22 Sep 2022 06:31:28 +0200 Subject: Merging upstream version 6.1.1. Signed-off-by: Daniel Baumann --- CHANGELOG.md | 11 + LICENSE | 2 +- run_checks.sh | 2 +- sqlglot/__init__.py | 2 +- sqlglot/__main__.py | 7 +- sqlglot/dialects/__init__.py | 1 + sqlglot/dialects/bigquery.py | 9 +- sqlglot/dialects/dialect.py | 31 +- sqlglot/dialects/duckdb.py | 5 +- sqlglot/dialects/hive.py | 15 +- sqlglot/dialects/mysql.py | 29 + sqlglot/dialects/oracle.py | 8 + sqlglot/dialects/postgres.py | 116 +++- sqlglot/dialects/presto.py | 6 +- sqlglot/dialects/redshift.py | 34 + sqlglot/dialects/snowflake.py | 4 +- sqlglot/dialects/spark.py | 15 +- sqlglot/dialects/sqlite.py | 1 + sqlglot/dialects/trino.py | 3 + sqlglot/diff.py | 35 +- sqlglot/executor/__init__.py | 10 +- sqlglot/executor/context.py | 4 +- sqlglot/executor/python.py | 14 +- sqlglot/executor/table.py | 5 +- sqlglot/expressions.py | 169 +++-- sqlglot/generator.py | 167 ++--- sqlglot/optimizer/__init__.py | 2 +- sqlglot/optimizer/isolate_table_selects.py | 4 +- sqlglot/optimizer/merge_derived_tables.py | 232 +++++++ sqlglot/optimizer/normalize.py | 22 +- sqlglot/optimizer/optimize_joins.py | 6 +- sqlglot/optimizer/optimizer.py | 39 +- sqlglot/optimizer/pushdown_predicates.py | 20 +- sqlglot/optimizer/qualify_columns.py | 36 +- sqlglot/optimizer/qualify_tables.py | 4 +- sqlglot/optimizer/schema.py | 4 +- sqlglot/optimizer/scope.py | 58 +- sqlglot/optimizer/simplify.py | 8 +- sqlglot/optimizer/unnest_subqueries.py | 22 +- sqlglot/parser.py | 404 ++++++------ sqlglot/planner.py | 21 +- sqlglot/tokens.py | 184 ++++-- sqlglot/transforms.py | 4 +- tests/dialects/test_dialect.py | 133 +++- tests/dialects/test_hive.py | 15 + tests/dialects/test_mysql.py | 52 +- tests/dialects/test_postgres.py | 93 ++- tests/dialects/test_redshift.py | 64 ++ tests/dialects/test_snowflake.py | 32 + tests/dialects/test_sqlite.py | 18 + tests/fixtures/identity.sql | 7 + tests/fixtures/optimizer/merge_derived_tables.sql | 63 ++ tests/fixtures/optimizer/optimizer.sql | 57 +- tests/fixtures/optimizer/tpc-h/tpc-h.sql | 761 +++++----------------- tests/helpers.py | 8 +- tests/test_build.py | 127 +--- tests/test_executor.py | 20 +- tests/test_expressions.py | 53 +- tests/test_optimizer.py | 33 +- tests/test_parser.py | 37 +- tests/test_transpile.py | 51 +- 61 files changed, 1844 insertions(+), 1555 deletions(-) create mode 100644 CHANGELOG.md create mode 100644 sqlglot/dialects/redshift.py create mode 100644 sqlglot/optimizer/merge_derived_tables.py create mode 100644 tests/dialects/test_redshift.py create mode 100644 tests/fixtures/optimizer/merge_derived_tables.sql diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..0eba6cc --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,11 @@ +Changelog +========= + +v6.1.0 +------ + +Changes: + +- New: mysql group\_concat separator [49a4099](https://github.com/tobymao/sqlglot/commit/49a4099adc93780eeffef8204af36559eab50a9f) + +- Improvement: Better nested select parsing [45603f](https://github.com/tobymao/sqlglot/commit/45603f14bf9146dc3f8b330b85a0e25b77630b9b) diff --git a/LICENSE b/LICENSE index 388cd5e..05dbdae 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2021 Toby Mao +Copyright (c) 2022 Toby Mao Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/run_checks.sh 
b/run_checks.sh index a7dddf4..770f443 100755 --- a/run_checks.sh +++ b/run_checks.sh @@ -8,5 +8,5 @@ python -m autoflake -i -r \ --remove-unused-variables \ sqlglot/ tests/ python -m isort --profile black sqlglot/ tests/ -python -m black sqlglot/ tests/ +python -m black --line-length 120 sqlglot/ tests/ python -m unittest diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 0007e34..3fa40ce 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -20,7 +20,7 @@ from sqlglot.generator import Generator from sqlglot.parser import Parser from sqlglot.tokens import Tokenizer, TokenType -__version__ = "6.0.4" +__version__ = "6.1.1" pretty = False diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py index 25200c4..4161259 100644 --- a/sqlglot/__main__.py +++ b/sqlglot/__main__.py @@ -49,12 +49,7 @@ args = parser.parse_args() error_level = sqlglot.ErrorLevel[args.error_level.upper()] if args.parse: - sqls = [ - repr(expression) - for expression in sqlglot.parse( - args.sql, read=args.read, error_level=error_level - ) - ] + sqls = [repr(expression) for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)] else: sqls = sqlglot.transpile( args.sql, diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py index 5aa7d77..f7d03ad 100644 --- a/sqlglot/dialects/__init__.py +++ b/sqlglot/dialects/__init__.py @@ -7,6 +7,7 @@ from sqlglot.dialects.mysql import MySQL from sqlglot.dialects.oracle import Oracle from sqlglot.dialects.postgres import Postgres from sqlglot.dialects.presto import Presto +from sqlglot.dialects.redshift import Redshift from sqlglot.dialects.snowflake import Snowflake from sqlglot.dialects.spark import Spark from sqlglot.dialects.sqlite import SQLite diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index f4e87c3..1f1f90a 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -44,6 +44,7 @@ class BigQuery(Dialect): ] IDENTIFIERS = ["`"] ESCAPE = "\\" + HEX_STRINGS = [("0x", ""), ("0X", "")] KEYWORDS = { **Tokenizer.KEYWORDS, @@ -120,9 +121,5 @@ class BigQuery(Dialect): def intersect_op(self, expression): if not expression.args.get("distinct", False): - self.unsupported( - "INTERSECT without DISTINCT is not supported in BigQuery" - ) - return ( - f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" - ) + self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery") + return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 8045f7a..f338c81 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -20,6 +20,7 @@ class Dialects(str, Enum): ORACLE = "oracle" POSTGRES = "postgres" PRESTO = "presto" + REDSHIFT = "redshift" SNOWFLAKE = "snowflake" SPARK = "spark" SQLITE = "sqlite" @@ -53,12 +54,19 @@ class _Dialect(type): klass.generator_class = getattr(klass, "Generator", Generator) klass.tokenizer = klass.tokenizer_class() - klass.quote_start, klass.quote_end = list(klass.tokenizer_class.QUOTES.items())[ - 0 - ] - klass.identifier_start, klass.identifier_end = list( - klass.tokenizer_class.IDENTIFIERS.items() - )[0] + klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0] + klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0] + + if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS: + bs_start, bs_end = 
list(klass.tokenizer_class._BIT_STRINGS.items())[0] + klass.generator_class.TRANSFORMS[ + exp.BitString + ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}" + if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS: + hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0] + klass.generator_class.TRANSFORMS[ + exp.HexString + ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}" return klass @@ -122,9 +130,7 @@ class Dialect(metaclass=_Dialect): return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql) def parse_into(self, expression_type, sql, **opts): - return self.parser(**opts).parse_into( - expression_type, self.tokenizer.tokenize(sql), sql - ) + return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql) def generate(self, expression, **opts): return self.generator(**opts).generate(expression) @@ -164,9 +170,7 @@ class Dialect(metaclass=_Dialect): def rename_func(name): - return ( - lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})" - ) + return lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})" def approx_count_distinct_sql(self, expression): @@ -260,8 +264,7 @@ def format_time_lambda(exp_class, dialect, default=None): return exp_class( this=list_get(args, 0), format=Dialect[dialect].format_time( - list_get(args, 1) - or (Dialect[dialect].time_format if default is True else default) + list_get(args, 1) or (Dialect[dialect].time_format if default is True else default) ), ) diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index d83a620..ff3a8b1 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -63,10 +63,7 @@ def _sort_array_reverse(args): def _struct_pack_sql(self, expression): - args = [ - self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) - for e in expression.expressions - ] + args = [self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) for e in expression.expressions] return f"STRUCT_PACK({', '.join(args)})" diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index e3f3f39..59aa8fa 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -109,9 +109,7 @@ def _unnest_to_explode_sql(self, expression): alias=exp.TableAlias(this=alias.this, columns=[column]), ) ) - for expression, column in zip( - unnest.expressions, alias.columns if alias else [] - ) + for expression, column in zip(unnest.expressions, alias.columns if alias else []) ) return self.join_sql(expression) @@ -206,14 +204,11 @@ class Hive(Dialect): substr=list_get(args, 0), position=list_get(args, 2), ), - "LOG": ( - lambda args: exp.Log.from_arg_list(args) - if len(args) > 1 - else exp.Ln.from_arg_list(args) - ), + "LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)), "MAP": _parse_map, "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), "PERCENTILE": exp.Quantile.from_arg_list, + "PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list, "COLLECT_SET": exp.SetAgg.from_arg_list, "SIZE": exp.ArraySize.from_arg_list, "SPLIT": exp.RegexpSplit.from_arg_list, @@ -262,6 +257,7 @@ class Hive(Dialect): HiveMap: _map_sql, exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e.args['value'])}", exp.Quantile: rename_func("PERCENTILE"), + exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"), exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"), 
exp.RegexpSplit: rename_func("SPLIT"), exp.SafeDivide: no_safe_divide_sql, @@ -296,8 +292,7 @@ class Hive(Dialect): def datatype_sql(self, expression): if ( - expression.this - in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR) + expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR) and not expression.expressions ): expression = exp.DataType.build("text") diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 93800a6..87a2c41 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -49,6 +49,21 @@ def _str_to_date_sql(self, expression): return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})" +def _trim_sql(self, expression): + target = self.sql(expression, "this") + trim_type = self.sql(expression, "position") + remove_chars = self.sql(expression, "expression") + + # Use TRIM/LTRIM/RTRIM syntax if the expression isn't mysql-specific + if not remove_chars: + return self.trim_sql(expression) + + trim_type = f"{trim_type} " if trim_type else "" + remove_chars = f"{remove_chars} " if remove_chars else "" + from_part = "FROM " if trim_type or remove_chars else "" + return f"TRIM({trim_type}{remove_chars}{from_part}{target})" + + def _date_add(expression_class): def func(args): interval = list_get(args, 1) @@ -88,9 +103,12 @@ class MySQL(Dialect): QUOTES = ["'", '"'] COMMENTS = ["--", "#", ("/*", "*/")] IDENTIFIERS = ["`"] + BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")] + HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")] KEYWORDS = { **Tokenizer.KEYWORDS, + "SEPARATOR": TokenType.SEPARATOR, "_ARMSCII8": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER, "_BIG5": TokenType.INTRODUCER, @@ -145,6 +163,15 @@ class MySQL(Dialect): "STR_TO_DATE": _str_to_date, } + FUNCTION_PARSERS = { + **Parser.FUNCTION_PARSERS, + "GROUP_CONCAT": lambda self: self.expression( + exp.GroupConcat, + this=self._parse_lambda(), + separator=self._match(TokenType.SEPARATOR) and self._parse_field(), + ), + } + class Generator(Generator): NULL_ORDERING_SUPPORTED = False @@ -158,6 +185,8 @@ class MySQL(Dialect): exp.DateAdd: _date_add_sql("ADD"), exp.DateSub: _date_add_sql("SUB"), exp.DateTrunc: _date_trunc_sql, + exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""", exp.StrToDate: _str_to_date_sql, exp.StrToTime: _str_to_date_sql, + exp.Trim: _trim_sql, } diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 9c8b6f2..91e30b2 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -51,6 +51,14 @@ class Oracle(Dialect): sep="", ) + def alias_sql(self, expression): + if isinstance(expression.this, exp.Table): + to_sql = self.sql(expression, "alias") + # oracle does not allow "AS" between table and alias + to_sql = f" {to_sql}" if to_sql else "" + return f"{self.sql(expression, 'this')}{to_sql}" + return super().alias_sql(expression) + def offset_sql(self, expression): return f"{super().offset_sql(expression)} ROWS" diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 61dff86..c796839 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import ( from sqlglot.generator import Generator from sqlglot.parser import Parser from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.transforms import delegate, preprocess def _date_add_sql(kind): @@ -32,11 +33,96 @@ def _date_add_sql(kind): return func +def _lateral_sql(self, expression): + this 
= self.sql(expression, "this") + if isinstance(expression.this, exp.Subquery): + return f"LATERAL{self.sep()}{this}" + alias = expression.args["alias"] + table = alias.name + table = f" {table}" if table else table + columns = self.expressions(alias, key="columns", flat=True) + columns = f" AS {columns}" if columns else "" + return f"LATERAL{self.sep()}{this}{table}{columns}" + + +def _substring_sql(self, expression): + this = self.sql(expression, "this") + start = self.sql(expression, "start") + length = self.sql(expression, "length") + + from_part = f" FROM {start}" if start else "" + for_part = f" FOR {length}" if length else "" + + return f"SUBSTRING({this}{from_part}{for_part})" + + +def _trim_sql(self, expression): + target = self.sql(expression, "this") + trim_type = self.sql(expression, "position") + remove_chars = self.sql(expression, "expression") + collation = self.sql(expression, "collation") + + # Use TRIM/LTRIM/RTRIM syntax if the expression isn't postgres-specific + if not remove_chars and not collation: + return self.trim_sql(expression) + + trim_type = f"{trim_type} " if trim_type else "" + remove_chars = f"{remove_chars} " if remove_chars else "" + from_part = "FROM " if trim_type or remove_chars else "" + collation = f" COLLATE {collation}" if collation else "" + return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" + + +def _auto_increment_to_serial(expression): + auto = expression.find(exp.AutoIncrementColumnConstraint) + + if auto: + expression = expression.copy() + expression.args["constraints"].remove(auto.parent) + kind = expression.args["kind"] + + if kind.this == exp.DataType.Type.INT: + kind.replace(exp.DataType(this=exp.DataType.Type.SERIAL)) + elif kind.this == exp.DataType.Type.SMALLINT: + kind.replace(exp.DataType(this=exp.DataType.Type.SMALLSERIAL)) + elif kind.this == exp.DataType.Type.BIGINT: + kind.replace(exp.DataType(this=exp.DataType.Type.BIGSERIAL)) + + return expression + + +def _serial_to_generated(expression): + kind = expression.args["kind"] + + if kind.this == exp.DataType.Type.SERIAL: + data_type = exp.DataType(this=exp.DataType.Type.INT) + elif kind.this == exp.DataType.Type.SMALLSERIAL: + data_type = exp.DataType(this=exp.DataType.Type.SMALLINT) + elif kind.this == exp.DataType.Type.BIGSERIAL: + data_type = exp.DataType(this=exp.DataType.Type.BIGINT) + else: + data_type = None + + if data_type: + expression = expression.copy() + expression.args["kind"].replace(data_type) + constraints = expression.args["constraints"] + generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False)) + notnull = exp.ColumnConstraint(kind=exp.NotNullColumnConstraint()) + if notnull not in constraints: + constraints.insert(0, notnull) + if generated not in constraints: + constraints.insert(0, generated) + + return expression + + class Postgres(Dialect): null_ordering = "nulls_are_large" time_format = "'YYYY-MM-DD HH24:MI:SS'" time_mapping = { - "AM": "%p", # AM or PM + "AM": "%p", + "PM": "%p", "D": "%w", # 1-based day of week "DD": "%d", # day of month "DDD": "%j", # zero padded day of year @@ -65,14 +151,25 @@ class Postgres(Dialect): } class Tokenizer(Tokenizer): + BIT_STRINGS = [("b'", "'"), ("B'", "'")] + HEX_STRINGS = [("x'", "'"), ("X'", "'")] KEYWORDS = { **Tokenizer.KEYWORDS, - "SERIAL": TokenType.AUTO_INCREMENT, + "ALWAYS": TokenType.ALWAYS, + "BY DEFAULT": TokenType.BY_DEFAULT, + "IDENTITY": TokenType.IDENTITY, + "FOR": TokenType.FOR, + "GENERATED": TokenType.GENERATED, + "DOUBLE PRECISION": 
TokenType.DOUBLE, + "BIGSERIAL": TokenType.BIGSERIAL, + "SERIAL": TokenType.SERIAL, + "SMALLSERIAL": TokenType.SMALLSERIAL, "UUID": TokenType.UUID, } class Parser(Parser): STRICT_CAST = False + FUNCTIONS = { **Parser.FUNCTIONS, "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"), @@ -86,14 +183,18 @@ class Postgres(Dialect): exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", exp.DataType.Type.BINARY: "BYTEA", - } - - TOKEN_MAPPING = { - TokenType.AUTO_INCREMENT: "SERIAL", + exp.DataType.Type.DATETIME: "TIMESTAMP", } TRANSFORMS = { **Generator.TRANSFORMS, + exp.ColumnDef: preprocess( + [ + _auto_increment_to_serial, + _serial_to_generated, + ], + delegate("columndef_sql"), + ), exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}", @@ -102,8 +203,11 @@ class Postgres(Dialect): exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: _date_add_sql("+"), exp.DateSub: _date_add_sql("-"), + exp.Lateral: _lateral_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.Substring: _substring_sql, exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TableSample: no_tablesample_sql, + exp.Trim: _trim_sql, exp.TryCast: no_trycast_sql, } diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index ca913e4..7253f7e 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -96,9 +96,7 @@ def _ts_or_ds_to_date_sql(self, expression): time_format = self.format_time(expression) if time_format and time_format not in (Presto.time_format, Presto.date_format): return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" - return ( - f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)" - ) + return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)" def _ts_or_ds_add_sql(self, expression): @@ -141,6 +139,7 @@ class Presto(Dialect): "FROM_UNIXTIME": exp.UnixToTime.from_arg_list, "STRPOS": exp.StrPosition.from_arg_list, "TO_UNIXTIME": exp.TimeToUnix.from_arg_list, + "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, } class Generator(Generator): @@ -193,6 +192,7 @@ class Presto(Dialect): exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY = {self.sql(e.args['value'])}", exp.Quantile: _quantile_sql, + exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, exp.SortArray: _no_sort_array, diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py new file mode 100644 index 0000000..e1f7b78 --- /dev/null +++ b/sqlglot/dialects/redshift.py @@ -0,0 +1,34 @@ +from sqlglot import exp +from sqlglot.dialects.postgres import Postgres +from sqlglot.tokens import TokenType + + +class Redshift(Postgres): + time_format = "'YYYY-MM-DD HH:MI:SS'" + time_mapping = { + **Postgres.time_mapping, + "MON": "%b", + "HH": "%H", + } + + class Tokenizer(Postgres.Tokenizer): + ESCAPE = "\\" + + KEYWORDS = { + **Postgres.Tokenizer.KEYWORDS, + "GEOMETRY": TokenType.GEOMETRY, + "GEOGRAPHY": TokenType.GEOGRAPHY, + "HLLSKETCH": TokenType.HLLSKETCH, + "SUPER": TokenType.SUPER, + "TIME": TokenType.TIMESTAMP, + "TIMETZ": TokenType.TIMESTAMPTZ, + "VARBYTE": TokenType.BINARY, + "SIMILAR TO": TokenType.SIMILAR_TO, + } + + class Generator(Postgres.Generator): + TYPE_MAPPING = 
{ + **Postgres.Generator.TYPE_MAPPING, + exp.DataType.Type.BINARY: "VARBYTE", + exp.DataType.Type.INT: "INTEGER", + } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 148dfb5..8d6ee78 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -23,9 +23,7 @@ def _snowflake_to_timestamp(args): # case: <numeric_expr> [ , <scale> ] if second_arg.name not in ["0", "3", "9"]: - raise ValueError( - f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9" - ) + raise ValueError(f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9") if second_arg.name == "0": timescale = exp.UnixToTime.SECONDS diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 89c7ed5..a331191 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -65,12 +65,11 @@ class Spark(Hive): this=list_get(args, 0), start=exp.Sub( this=exp.Length(this=list_get(args, 0)), - expression=exp.Add( - this=list_get(args, 1), expression=exp.Literal.number(1) - ), + expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)), ), length=list_get(args, 1), ), + "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, } class Generator(Hive.Generator): @@ -82,11 +81,7 @@ class Spark(Hive): } TRANSFORMS = { - **{ - k: v - for k, v in Hive.Generator.TRANSFORMS.items() - if k not in {exp.ArraySort} - }, + **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort}}, exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), @@ -102,5 +97,5 @@ class Spark(Hive): HiveMap: _map_sql, } - def bitstring_sql(self, expression): - return f"X'{self.sql(expression, 'this')}'" + class Tokenizer(Hive.Tokenizer): + HEX_STRINGS = [("X'", "'")] diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 6cf5022..cfdbe1b 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -16,6 +16,7 @@ from sqlglot.tokens import Tokenizer, TokenType class SQLite(Dialect): class Tokenizer(Tokenizer): IDENTIFIERS = ['"', ("[", "]"), "`"] + HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")] KEYWORDS = { **Tokenizer.KEYWORDS, diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py index 805106c..9a6f7fe 100644 --- a/sqlglot/dialects/trino.py +++ b/sqlglot/dialects/trino.py @@ -8,3 +8,6 @@ class Trino(Presto): **Presto.Generator.TRANSFORMS, exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", } + + class Tokenizer(Presto.Tokenizer): + HEX_STRINGS = [("X'", "'")] diff --git a/sqlglot/diff.py b/sqlglot/diff.py index 8eeb4e9..0567c12 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -115,13 +115,8 @@ class ChangeDistiller: for kept_source_node_id, kept_target_node_id in matching_set: source_node = self._source_index[kept_source_node_id] target_node = self._target_index[kept_target_node_id] - if ( - not isinstance(source_node, LEAF_EXPRESSION_TYPES) - or source_node == target_node - ): - edit_script.extend( - self._generate_move_edits(source_node, target_node, matching_set) - ) + if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node: + edit_script.extend(self._generate_move_edits(source_node, target_node, matching_set)) edit_script.append(Keep(source_node, target_node)) else: edit_script.append(Update(source_node, target_node)) @@ -132,9 +127,7 @@ class
ChangeDistiller: source_args = [id(e) for e in _expression_only_args(source)] target_args = [id(e) for e in _expression_only_args(target)] - args_lcs = set( - _lcs(source_args, target_args, lambda l, r: (l, r) in matching_set) - ) + args_lcs = set(_lcs(source_args, target_args, lambda l, r: (l, r) in matching_set)) move_edits = [] for a in source_args: @@ -148,14 +141,10 @@ class ChangeDistiller: matching_set = leaves_matching_set.copy() ordered_unmatched_source_nodes = { - id(n[0]): None - for n in self._source.bfs() - if id(n[0]) in self._unmatched_source_nodes + id(n[0]): None for n in self._source.bfs() if id(n[0]) in self._unmatched_source_nodes } ordered_unmatched_target_nodes = { - id(n[0]): None - for n in self._target.bfs() - if id(n[0]) in self._unmatched_target_nodes + id(n[0]): None for n in self._target.bfs() if id(n[0]) in self._unmatched_target_nodes } for source_node_id in ordered_unmatched_source_nodes: @@ -169,18 +158,13 @@ class ChangeDistiller: max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids)) if max_leaves_num: common_leaves_num = sum( - 1 if s in source_leaf_ids and t in target_leaf_ids else 0 - for s, t in leaves_matching_set + 1 if s in source_leaf_ids and t in target_leaf_ids else 0 for s, t in leaves_matching_set ) leaf_similarity_score = common_leaves_num / max_leaves_num else: leaf_similarity_score = 0.0 - adjusted_t = ( - self.t - if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 - else 0.4 - ) + adjusted_t = self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4 if leaf_similarity_score >= 0.8 or ( leaf_similarity_score >= adjusted_t @@ -217,10 +201,7 @@ class ChangeDistiller: matching_set = set() while candidate_matchings: _, _, source_leaf, target_leaf = heappop(candidate_matchings) - if ( - id(source_leaf) in self._unmatched_source_nodes - and id(target_leaf) in self._unmatched_target_nodes - ): + if id(source_leaf) in self._unmatched_source_nodes and id(target_leaf) in self._unmatched_target_nodes: matching_set.add((id(source_leaf), id(target_leaf))) self._unmatched_source_nodes.remove(id(source_leaf)) self._unmatched_target_nodes.remove(id(target_leaf)) diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py index a437431..bca9f3e 100644 --- a/sqlglot/executor/__init__.py +++ b/sqlglot/executor/__init__.py @@ -3,11 +3,17 @@ import time from sqlglot import parse_one from sqlglot.executor.python import PythonExecutor -from sqlglot.optimizer import optimize +from sqlglot.optimizer import RULES, optimize +from sqlglot.optimizer.merge_derived_tables import merge_derived_tables from sqlglot.planner import Plan logger = logging.getLogger("sqlglot") +OPTIMIZER_RULES = list(RULES) + +# The executor needs isolated table selects +OPTIMIZER_RULES.remove(merge_derived_tables) + def execute(sql, schema, read=None): """ @@ -28,7 +34,7 @@ def execute(sql, schema, read=None): """ expression = parse_one(sql, read=read) now = time.time() - expression = optimize(expression, schema) + expression = optimize(expression, schema, rules=OPTIMIZER_RULES) logger.debug("Optimization finished: %f", time.time() - now) logger.debug("Optimized SQL: %s", expression.sql(pretty=True)) plan = Plan(expression) diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py index 457bea7..d265a2c 100644 --- a/sqlglot/executor/context.py +++ b/sqlglot/executor/context.py @@ -19,9 +19,7 @@ class Context: env (Optional[dict]): dictionary of functions within the execution context """ self.tables = tables - self.range_readers = { - name: 
table.range_reader for name, table in self.tables.items() - } + self.range_readers = {name: table.range_reader for name, table in self.tables.items()} self.row_readers = {name: table.reader for name, table in tables.items()} self.env = {**(env or {}), "scope": self.row_readers} diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index 388a419..610aa4b 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -26,11 +26,7 @@ class PythonExecutor: while queue: node = queue.pop() context = self.context( - { - name: table - for dep in node.dependencies - for name, table in contexts[dep].tables.items() - } + {name: table for dep in node.dependencies for name, table in contexts[dep].tables.items()} ) running.add(node) @@ -151,9 +147,7 @@ class PythonExecutor: return self.context({name: table for name in ctx.tables}) for name, join in step.joins.items(): - join_context = self.context( - {**join_context.tables, name: context.tables[name]} - ) + join_context = self.context({**join_context.tables, name: context.tables[name]}) if join.get("source_key"): table = self.hash_join(join, source, name, join_context) @@ -247,9 +241,7 @@ class PythonExecutor: if step.operands: source_table = context.tables[source] - operand_table = Table( - source_table.columns + self.table(step.operands).columns - ) + operand_table = Table(source_table.columns + self.table(step.operands).columns) for reader, ctx in context: operand_table.append(reader.row + ctx.eval_tuple(operands)) diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py index 6df49f7..80674cb 100644 --- a/sqlglot/executor/table.py +++ b/sqlglot/executor/table.py @@ -37,10 +37,7 @@ class Table: break lines.append( - " ".join( - str(row[column]).rjust(widths[column])[0 : widths[column]] - for column in self.columns - ) + " ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in self.columns) ) return "\n".join(lines) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 7acc63d..b983bf9 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -47,10 +47,7 @@ class Expression(metaclass=_Expression): return hash( ( self.key, - tuple( - (k, tuple(v) if isinstance(v, list) else v) - for k, v in _norm_args(self).items() - ), + tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()), ) ) @@ -116,9 +113,22 @@ class Expression(metaclass=_Expression): item.parent = parent return new + def append(self, arg_key, value): + """ + Appends value to arg_key if it's a list or sets it as a new list. + + Args: + arg_key (str): name of the list expression arg + value (Any): value to append to the list + """ + if not isinstance(self.args.get(arg_key), list): + self.args[arg_key] = [] + self.args[arg_key].append(value) + self._set_parent(arg_key, value) + def set(self, arg_key, value): """ - Sets `arg` to `value`. + Sets `arg_key` to `value`. Args: arg_key (str): name of the expression arg @@ -267,6 +277,14 @@ class Expression(metaclass=_Expression): expression = expression.this return expression + def unalias(self): + """ + Returns the inner expression if this is an Alias. + """ + if isinstance(self, Alias): + return self.this + return self + def unnest_operands(self): """ Returns unnested operands as a tuple. 
@@ -279,9 +297,7 @@ class Expression(metaclass=_Expression): A AND B AND C -> [A, B, C] """ - for node, _, _ in self.dfs( - prune=lambda n, p, *_: p and not isinstance(n, self.__class__) - ): + for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)): if not isinstance(node, self.__class__): yield node.unnest() if unnest else node @@ -314,9 +330,7 @@ class Expression(metaclass=_Expression): args = { k: ", ".join( - v.to_s(hide_missing=hide_missing, level=level + 1) - if hasattr(v, "to_s") - else str(v) + v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v) for v in ensure_list(vs) if v is not None ) @@ -354,9 +368,7 @@ class Expression(metaclass=_Expression): new_node.parent = node.parent return new_node - replace_children( - new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs) - ) + replace_children(new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs)) return new_node def replace(self, expression): @@ -546,6 +558,10 @@ class BitString(Condition): pass +class HexString(Condition): + pass + + class Column(Condition): arg_types = {"this": True, "table": False} @@ -566,35 +582,44 @@ class ColumnConstraint(Expression): arg_types = {"this": False, "kind": True} -class AutoIncrementColumnConstraint(Expression): +class ColumnConstraintKind(Expression): pass -class CheckColumnConstraint(Expression): +class AutoIncrementColumnConstraint(ColumnConstraintKind): pass -class CollateColumnConstraint(Expression): +class CheckColumnConstraint(ColumnConstraintKind): pass -class CommentColumnConstraint(Expression): +class CollateColumnConstraint(ColumnConstraintKind): pass -class DefaultColumnConstraint(Expression): +class CommentColumnConstraint(ColumnConstraintKind): pass -class NotNullColumnConstraint(Expression): +class DefaultColumnConstraint(ColumnConstraintKind): pass -class PrimaryKeyColumnConstraint(Expression): +class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): + # this: True -> ALWAYS, this: False -> BY DEFAULT + arg_types = {"this": True, "expression": False} + + +class NotNullColumnConstraint(ColumnConstraintKind): pass -class UniqueColumnConstraint(Expression): +class PrimaryKeyColumnConstraint(ColumnConstraintKind): + pass + + +class UniqueColumnConstraint(ColumnConstraintKind): pass @@ -651,9 +676,7 @@ class Identifier(Expression): return bool(self.args.get("quoted")) def __eq__(self, other): - return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg( - other.this - ) + return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this) def __hash__(self): return hash((self.key, self.this.lower())) @@ -709,9 +732,7 @@ class Literal(Condition): def __eq__(self, other): return ( - isinstance(other, Literal) - and self.this == other.this - and self.args["is_string"] == other.args["is_string"] + isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"] ) def __hash__(self): @@ -733,6 +754,7 @@ class Join(Expression): "side": False, "kind": False, "using": False, + "natural": False, } @property @@ -743,6 +765,10 @@ class Join(Expression): def side(self): return self.text("side").upper() + @property + def alias_or_name(self): + return self.this.alias_or_name + def on(self, *expressions, append=True, dialect=None, copy=True, **opts): """ Append to or set the ON expressions. 
@@ -873,10 +899,6 @@ class Reference(Expression): arg_types = {"this": True, "expressions": True} -class Table(Expression): - arg_types = {"this": True, "db": False, "catalog": False} - - class Tuple(Expression): arg_types = {"expressions": False} @@ -986,6 +1008,16 @@ QUERY_MODIFIERS = { } +class Table(Expression): + arg_types = { + "this": True, + "db": False, + "catalog": False, + "laterals": False, + "joins": False, + } + + class Union(Subqueryable, Expression): arg_types = { "with": False, @@ -1396,7 +1428,9 @@ class Select(Subqueryable, Expression): join.this.replace(join.this.subquery()) if join_type: - side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) + natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) + if natural: + join.set("natural", True) if side: join.set("side", side.text) if kind: @@ -1529,10 +1563,7 @@ class Select(Subqueryable, Expression): properties_expression = None if properties: properties_str = " ".join( - [ - f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" - for k, v in properties.items() - ] + [f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in properties.items()] ) properties_expression = maybe_parse( properties_str, @@ -1654,6 +1685,7 @@ class DataType(Expression): DECIMAL = auto() BOOLEAN = auto() JSON = auto() + INTERVAL = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() DATE = auto() @@ -1662,15 +1694,19 @@ class DataType(Expression): MAP = auto() UUID = auto() GEOGRAPHY = auto() + GEOMETRY = auto() STRUCT = auto() NULLABLE = auto() + HLLSKETCH = auto() + SUPER = auto() + SERIAL = auto() + SMALLSERIAL = auto() + BIGSERIAL = auto() @classmethod def build(cls, dtype, **kwargs): return DataType( - this=dtype - if isinstance(dtype, DataType.Type) - else DataType.Type[dtype.upper()], + this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()], **kwargs, ) @@ -1798,6 +1834,14 @@ class Like(Binary, Predicate): pass +class SimilarTo(Binary, Predicate): + pass + + +class Distance(Binary): + pass + + class LT(Binary, Predicate): pass @@ -1899,6 +1943,10 @@ class IgnoreNulls(Expression): pass +class RespectNulls(Expression): + pass + + # Functions class Func(Condition): """ @@ -1924,9 +1972,7 @@ class Func(Condition): all_arg_keys = list(cls.arg_types) # If this function supports variable length arguments, treat the last argument as such.
- non_var_len_arg_keys = ( - all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys - ) + non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys args_dict = {} arg_idx = 0 @@ -1944,9 +1990,7 @@ class Func(Condition): @classmethod def sql_names(cls): if cls is Func: - raise NotImplementedError( - "SQL name is only supported by concrete function implementations" - ) + raise NotImplementedError("SQL name is only supported by concrete function implementations") if not hasattr(cls, "_sql_names"): cls._sql_names = [camel_to_snake_case(cls.__name__)] return cls._sql_names @@ -2178,6 +2222,10 @@ class Greatest(Func): is_var_len_args = True +class GroupConcat(Func): + arg_types = {"this": True, "separator": False} + + class If(Func): arg_types = {"this": True, "true": True, "false": False} @@ -2274,6 +2322,10 @@ class Quantile(AggFunc): arg_types = {"this": True, "quantile": True} +class ApproxQuantile(Quantile): + pass + + class Reduce(Func): arg_types = {"this": True, "initial": True, "merge": True, "finish": True} @@ -2306,8 +2358,10 @@ class Split(Func): arg_types = {"this": True, "expression": True} +# Start may be omitted in the case of postgres +# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 class Substring(Func): - arg_types = {"this": True, "start": True, "length": False} + arg_types = {"this": True, "start": False, "length": False} class StrPosition(Func): @@ -2379,6 +2433,15 @@ class TimeStrToUnix(Func): pass +class Trim(Func): + arg_types = { + "this": True, + "position": False, + "expression": False, + "collation": False, + } + + class TsOrDsAdd(Func, TimeUnit): arg_types = {"this": True, "expression": True, "unit": False} @@ -2455,9 +2518,7 @@ def _all_functions(): obj for _, obj in inspect.getmembers( sys.modules[__name__], - lambda obj: inspect.isclass(obj) - and issubclass(obj, Func) - and obj not in (AggFunc, Anonymous, Func), + lambda obj: inspect.isclass(obj) and issubclass(obj, Func) and obj not in (AggFunc, Anonymous, Func), ) ] @@ -2633,9 +2694,7 @@ def _apply_conjunction_builder( def _combine(expressions, operator, dialect=None, **opts): - expressions = [ - condition(expression, dialect=dialect, **opts) for expression in expressions - ] + expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions] this = expressions[0] if expressions[1:]: this = _wrap_operator(this) @@ -2809,9 +2868,7 @@ def to_identifier(alias, quoted=None): quoted = not re.match(SAFE_IDENTIFIER_RE, alias) identifier = Identifier(this=alias, quoted=quoted) else: - raise ValueError( - f"Alias needs to be a string or an Identifier, got: {alias.__class__}" - ) + raise ValueError(f"Alias needs to be a string or an Identifier, got: {alias.__class__}") return identifier diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 793cff0..a445178 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -41,6 +41,8 @@ class Generator: max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError. This is only relevant if unsupported_level is ErrorLevel.RAISE. 
Default: 3 + leading_comma (bool): whether the comma is leading or trailing in select statements + Default: False """ TRANSFORMS = { @@ -108,6 +110,7 @@ class Generator: "_indent", "_replace_backslash", "_escaped_quote_end", + "_leading_comma", ) def __init__( @@ -131,6 +134,7 @@ class Generator: unsupported_level=ErrorLevel.WARN, null_ordering=None, max_unsupported=3, + leading_comma=False, ): import sqlglot @@ -157,6 +161,7 @@ class Generator: self._indent = indent self._replace_backslash = self.escape == "\\" self._escaped_quote_end = self.escape + self.quote_end + self._leading_comma = leading_comma def generate(self, expression): """ @@ -178,9 +183,7 @@ class Generator: for msg in self.unsupported_messages: logger.warning(msg) elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages: - raise UnsupportedError( - concat_errors(self.unsupported_messages, self.max_unsupported) - ) + raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported)) return sql @@ -197,9 +200,7 @@ class Generator: def wrap(self, expression): this_sql = self.indent( - self.sql(expression) - if isinstance(expression, (exp.Select, exp.Union)) - else self.sql(expression, "this"), + self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"), level=1, pad=0, ) @@ -251,9 +252,7 @@ class Generator: return transform if not isinstance(expression, exp.Expression): - raise ValueError( - f"Expected an Expression. Received {type(expression)}: {expression}" - ) + raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}") exp_handler_name = f"{expression.key}_sql" if hasattr(self, exp_handler_name): @@ -276,11 +275,7 @@ class Generator: lazy = " LAZY" if expression.args.get("lazy") else "" table = self.sql(expression, "this") options = expression.args.get("options") - options = ( - f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})" - if options - else "" - ) + options = f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})" if options else "" sql = self.sql(expression, "expression") sql = f" AS{self.sep()}{sql}" if sql else "" sql = f"CACHE{lazy} TABLE {table}{options}{sql}" @@ -306,9 +301,7 @@ class Generator: def columndef_sql(self, expression): column = self.sql(expression, "this") kind = self.sql(expression, "kind") - constraints = self.expressions( - expression, key="constraints", sep=" ", flat=True - ) + constraints = self.expressions(expression, key="constraints", sep=" ", flat=True) if not constraints: return f"{column} {kind}" @@ -338,6 +331,9 @@ class Generator: default = self.sql(expression, "this") return f"DEFAULT {default}" + def generatedasidentitycolumnconstraint_sql(self, expression): + return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY" + def notnullcolumnconstraint_sql(self, _): return "NOT NULL" @@ -384,7 +380,10 @@ class Generator: return f"{alias}{columns}" def bitstring_sql(self, expression): - return f"b'{self.sql(expression, 'this')}'" + return self.sql(expression, "this") + + def hexstring_sql(self, expression): + return self.sql(expression, "this") def datatype_sql(self, expression): type_value = expression.this @@ -452,10 +451,7 @@ class Generator: def partition_sql(self, expression): keys = csv( - *[ - f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] - for k, v in expression.args.get("this") - ] + *[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")] ) return
f"PARTITION({keys})" @@ -470,9 +466,9 @@ class Generator: elif p_class in self.WITH_PROPERTIES: with_properties.append(p) - return self.root_properties( - exp.Properties(expressions=root_properties) - ) + self.with_properties(exp.Properties(expressions=with_properties)) + return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties( + exp.Properties(expressions=with_properties) + ) def root_properties(self, properties): if properties.expressions: @@ -508,11 +504,7 @@ class Generator: kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO" this = self.sql(expression, "this") exists = " IF EXISTS " if expression.args.get("exists") else " " - partition_sql = ( - self.sql(expression, "partition") - if expression.args.get("partition") - else "" - ) + partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else "" expression_sql = self.sql(expression, "expression") sep = self.sep() if partition_sql else "" sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}" @@ -531,7 +523,7 @@ class Generator: return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" def table_sql(self, expression): - return ".".join( + table = ".".join( part for part in [ self.sql(expression, "catalog"), @@ -541,6 +533,10 @@ class Generator: if part ) + laterals = self.expressions(expression, key="laterals", sep="") + joins = self.expressions(expression, key="joins", sep="") + return f"{table}{laterals}{joins}" + def tablesample_sql(self, expression): if self.alias_post_tablesample and isinstance(expression.this, exp.Alias): this = self.sql(expression.this, "this") @@ -586,11 +582,7 @@ class Generator: def group_sql(self, expression): group_by = self.op_expressions("GROUP BY", expression) grouping_sets = self.expressions(expression, key="grouping_sets", indent=False) - grouping_sets = ( - f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" - if grouping_sets - else "" - ) + grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else "" cube = self.expressions(expression, key="cube", indent=False) cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else "" rollup = self.expressions(expression, key="rollup", indent=False) @@ -603,7 +595,16 @@ class Generator: def join_sql(self, expression): op_sql = self.seg( - " ".join(op for op in (expression.side, expression.kind, "JOIN") if op) + " ".join( + op + for op in ( + "NATURAL" if expression.args.get("natural") else None, + expression.side, + expression.kind, + "JOIN", + ) + if op + ) ) on_sql = self.sql(expression, "on") using = expression.args.get("using") @@ -630,9 +631,9 @@ class Generator: def lateral_sql(self, expression): this = self.sql(expression, "this") - op_sql = self.seg( - f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}" - ) + if isinstance(expression.this, exp.Subquery): + return f"LATERAL{self.sep()}{this}" + op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}") alias = expression.args["alias"] table = alias.name table = f" {table}" if table else table @@ -688,21 +689,13 @@ class Generator: sort_order = " DESC" if desc else "" nulls_sort_change = "" - if nulls_first and ( - (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last - ): + if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last): nulls_sort_change = " NULLS FIRST" - elif ( - nulls_last - and ((asc and nulls_are_small) or (desc and 
nulls_are_large)) - and not nulls_are_last - ): + elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last: nulls_sort_change = " NULLS LAST" if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED: - self.unsupported( - "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect" - ) + self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect") nulls_sort_change = "" return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}" @@ -798,14 +791,20 @@ class Generator: def window_sql(self, expression): this = self.sql(expression, "this") + partition = self.expressions(expression, key="partition_by", flat=True) partition = f"PARTITION BY {partition}" if partition else "" + order = expression.args.get("order") order_sql = self.order_sql(order, flat=True) if order else "" + partition_sql = partition + " " if partition and order else partition + spec = expression.args.get("spec") spec_sql = " " + self.window_spec_sql(spec) if spec else "" + alias = self.sql(expression, "alias") + if expression.arg_key == "window": this = f"{self.seg('WINDOW')} {this} AS" else: @@ -818,13 +817,8 @@ class Generator: def window_spec_sql(self, expression): kind = self.sql(expression, "kind") - start = csv( - self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" " - ) - end = ( - csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") - or "CURRENT ROW" - ) + start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ") + end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW" return f"{kind} BETWEEN {start} AND {end}" def withingroup_sql(self, expression): @@ -879,6 +873,17 @@ class Generator: expression_sql = self.sql(expression, "expression") return f"EXTRACT({this} FROM {expression_sql})" + def trim_sql(self, expression): + target = self.sql(expression, "this") + trim_type = self.sql(expression, "position") + + if trim_type == "LEADING": + return f"LTRIM({target})" + elif trim_type == "TRAILING": + return f"RTRIM({target})" + else: + return f"TRIM({target})" + def check_sql(self, expression): this = self.sql(expression, key="this") return f"CHECK ({this})" @@ -898,9 +903,7 @@ class Generator: return f"UNIQUE ({columns})" def if_sql(self, expression): - return self.case_sql( - exp.Case(ifs=[expression], default=expression.args.get("false")) - ) + return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false"))) def in_sql(self, expression): query = expression.args.get("query") @@ -917,7 +920,9 @@ class Generator: return f"(SELECT {self.sql(unnest)})" def interval_sql(self, expression): - return f"INTERVAL {self.sql(expression, 'this')} {self.sql(expression, 'unit')}" + unit = self.sql(expression, "unit") + unit = f" {unit}" if unit else "" + return f"INTERVAL {self.sql(expression, 'this')}{unit}" def reference_sql(self, expression): this = self.sql(expression, "this") @@ -925,9 +930,7 @@ class Generator: return f"REFERENCES {this}({expressions})" def anonymous_sql(self, expression): - args = self.indent( - self.expressions(expression, flat=True), skip_first=True, skip_last=True - ) + args = self.indent(self.expressions(expression, flat=True), skip_first=True, skip_last=True) return f"{self.normalize_func(self.sql(expression, 'this'))}({args})" def paren_sql(self, expression): @@ -1006,6 +1009,9 @@ class Generator: def ignorenulls_sql(self, expression): return
f"{self.sql(expression, 'this')} IGNORE NULLS" + def respectnulls_sql(self, expression): + return f"{self.sql(expression, 'this')} RESPECT NULLS" + def intdiv_sql(self, expression): return self.sql( exp.Cast( @@ -1023,6 +1029,9 @@ class Generator: def div_sql(self, expression): return self.binary(expression, "/") + def distance_sql(self, expression): + return self.binary(expression, "<->") + def dot_sql(self, expression): return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}" @@ -1047,6 +1056,9 @@ class Generator: def like_sql(self, expression): return self.binary(expression, "LIKE") + def similarto_sql(self, expression): + return self.binary(expression, "SIMILAR TO") + def lt_sql(self, expression): return self.binary(expression, "<") @@ -1069,14 +1081,10 @@ class Generator: return self.binary(expression, "-") def trycast_sql(self, expression): - return ( - f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" - ) + return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" def binary(self, expression, op): - return ( - f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}" - ) + return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}" def function_fallback_sql(self, expression): args = [] @@ -1089,9 +1097,7 @@ class Generator: return f"{self.normalize_func(expression.sql_name())}({args_str})" def format_time(self, expression): - return format_time( - self.sql(expression, "format"), self.time_mapping, self.time_trie - ) + return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie) def expressions(self, expression, key=None, flat=False, indent=True, sep=", "): expressions = expression.args.get(key or "expressions") @@ -1102,7 +1108,14 @@ class Generator: if flat: return sep.join(self.sql(e) for e in expressions) - expressions = self.sep(sep).join(self.sql(e) for e in expressions) + sql = (self.sql(e) for e in expressions) + # the only time leading_comma changes the output is if pretty print is enabled + if self._leading_comma and self.pretty: + pad = " " * self.pad + expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql)) + else: + expressions = self.sep(sep).join(sql) + if indent: return self.indent(expressions, skip_first=False) return expressions @@ -1116,9 +1129,7 @@ class Generator: def set_operation(self, expression, op): this = self.sql(expression, "this") op = self.seg(op) - return self.query_modifiers( - expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}" - ) + return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}") def token_sql(self, token_type): return self.TOKEN_MAPPING.get(token_type, token_type.name) diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py index a4c4cc2..d1146ca 100644 --- a/sqlglot/optimizer/__init__.py +++ b/sqlglot/optimizer/__init__.py @@ -1,2 +1,2 @@ -from sqlglot.optimizer.optimizer import optimize +from sqlglot.optimizer.optimizer import RULES, optimize from sqlglot.optimizer.schema import Schema diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py index c2e021e..e060739 100644 --- a/sqlglot/optimizer/isolate_table_selects.py +++ b/sqlglot/optimizer/isolate_table_selects.py @@ -13,9 +13,7 @@ def isolate_table_selects(expression): continue if not isinstance(source.parent, exp.Alias): - raise OptimizeError( - "Tables require an alias. 
Run qualify_tables optimization." - ) + raise OptimizeError("Tables require an alias. Run qualify_tables optimization.") parent = source.parent diff --git a/sqlglot/optimizer/merge_derived_tables.py b/sqlglot/optimizer/merge_derived_tables.py new file mode 100644 index 0000000..8b161fb --- /dev/null +++ b/sqlglot/optimizer/merge_derived_tables.py @@ -0,0 +1,232 @@ +from collections import defaultdict + +from sqlglot import expressions as exp +from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.simplify import simplify + + +def merge_derived_tables(expression): + """ + Rewrite sqlglot AST to merge derived tables into the outer query. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)") + >>> merge_derived_tables(expression).sql() + 'SELECT x.a FROM x' + + Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html + + Args: + expression (sqlglot.Expression): expression to optimize + Returns: + sqlglot.Expression: optimized expression + """ + for outer_scope in traverse_scope(expression): + for subquery in outer_scope.derived_tables: + inner_select = subquery.unnest() + if ( + isinstance(outer_scope.expression, exp.Select) + and isinstance(inner_select, exp.Select) + and _mergeable(inner_select) + ): + alias = subquery.alias_or_name + from_or_join = subquery.find_ancestor(exp.From, exp.Join) + inner_scope = outer_scope.sources[alias] + + _rename_inner_sources(outer_scope, inner_scope, alias) + _merge_from(outer_scope, inner_scope, subquery) + _merge_joins(outer_scope, inner_scope, from_or_join) + _merge_expressions(outer_scope, inner_scope, alias) + _merge_where(outer_scope, inner_scope, from_or_join) + _merge_order(outer_scope, inner_scope) + return expression + + +# If a derived table has these Select args, it can't be merged +UNMERGABLE_ARGS = set(exp.Select.arg_types) - { + "expressions", + "from", + "joins", + "where", + "order", +} + + +def _mergeable(inner_select): + """ + Return True if `inner_select` can be merged into outer query. + + Args: + inner_select (exp.Select) + Returns: + bool: True if can be merged + """ + return ( + isinstance(inner_select, exp.Select) + and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) + and inner_select.args.get("from") + and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) + ) + + +def _rename_inner_sources(outer_scope, inner_scope, alias): + """ + Renames any sources in the inner query that conflict with names in the outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + alias (str) + """ + taken = set(outer_scope.selected_sources) + conflicts = taken.intersection(set(inner_scope.selected_sources)) + conflicts = conflicts - {alias} + + for conflict in conflicts: + new_name = _find_new_name(taken, conflict) + + source, _ = inner_scope.selected_sources[conflict] + new_alias = exp.to_identifier(new_name) + + if isinstance(source, exp.Subquery): + source.set("alias", exp.TableAlias(this=new_alias)) + elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias): + source.parent.set("alias", new_alias) + elif isinstance(source, exp.Table): + source.replace(exp.alias_(source.copy(), new_alias)) + + for column in inner_scope.source_columns(conflict): + column.set("table", exp.to_identifier(new_name)) + + inner_scope.rename_source(conflict, new_name) + + +def _find_new_name(taken, base): + """ + Searches for a new source name. 
+
+    Args:
+        taken (set[str]): set of taken names
+        base (str): base name to alter
+    """
+    i = 2
+    new = f"{base}_{i}"
+    while new in taken:
+        i += 1
+        new = f"{base}_{i}"
+    return new
+
+
+def _merge_from(outer_scope, inner_scope, subquery):
+    """
+    Merge FROM clause of inner query into outer query.
+
+    Args:
+        outer_scope (sqlglot.optimizer.scope.Scope)
+        inner_scope (sqlglot.optimizer.scope.Scope)
+        subquery (exp.Subquery)
+    """
+    new_subquery = inner_scope.expression.args.get("from").expressions[0]
+    subquery.replace(new_subquery)
+    outer_scope.remove_source(subquery.alias_or_name)
+    outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
+
+
+def _merge_joins(outer_scope, inner_scope, from_or_join):
+    """
+    Merge JOIN clauses of inner query into outer query.
+
+    Args:
+        outer_scope (sqlglot.optimizer.scope.Scope)
+        inner_scope (sqlglot.optimizer.scope.Scope)
+        from_or_join (exp.From|exp.Join)
+    """
+
+    new_joins = []
+    comma_joins = inner_scope.expression.args.get("from").expressions[1:]
+    for subquery in comma_joins:
+        new_joins.append(exp.Join(this=subquery, kind="CROSS"))
+        outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name])
+
+    joins = inner_scope.expression.args.get("joins") or []
+    for join in joins:
+        new_joins.append(join)
+        outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])
+
+    if new_joins:
+        outer_joins = outer_scope.expression.args.get("joins", [])
+
+        # Maintain the join order
+        if isinstance(from_or_join, exp.From):
+            position = 0
+        else:
+            position = outer_joins.index(from_or_join) + 1
+        outer_joins[position:position] = new_joins
+
+        outer_scope.expression.set("joins", outer_joins)
+
+
+def _merge_expressions(outer_scope, inner_scope, alias):
+    """
+    Merge projections of inner query into outer query.
+
+    Args:
+        outer_scope (sqlglot.optimizer.scope.Scope)
+        inner_scope (sqlglot.optimizer.scope.Scope)
+        alias (str)
+    """
+    # Collect all columns that reference the alias of the inner query
+    outer_columns = defaultdict(list)
+    for column in outer_scope.columns:
+        if column.table == alias:
+            outer_columns[column.name].append(column)
+
+    # Replace columns with the projection expression in the inner query
+    for expression in inner_scope.expression.expressions:
+        projection_name = expression.alias_or_name
+        if not projection_name:
+            continue
+        columns_to_replace = outer_columns.get(projection_name, [])
+        for column in columns_to_replace:
+            column.replace(expression.unalias())
+
+
+def _merge_where(outer_scope, inner_scope, from_or_join):
+    """
+    Merge WHERE clause of inner query into outer query.
+
+    Args:
+        outer_scope (sqlglot.optimizer.scope.Scope)
+        inner_scope (sqlglot.optimizer.scope.Scope)
+        from_or_join (exp.From|exp.Join)
+    """
+    where = inner_scope.expression.args.get("where")
+    if not where or not where.this:
+        return
+
+    if isinstance(from_or_join, exp.Join) and from_or_join.side:
+        # Merge predicates from an outer join to the ON clause
+        from_or_join.on(where.this, copy=False)
+        from_or_join.set("on", simplify(from_or_join.args.get("on")))
+    else:
+        outer_scope.expression.where(where.this, copy=False)
+        outer_scope.expression.set("where", simplify(outer_scope.expression.args.get("where")))
+
+
+def _merge_order(outer_scope, inner_scope):
+    """
+    Merge ORDER clause of inner query into outer query.
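+    Only applied when the outer query selects from a single source and has no
+    GROUP BY, DISTINCT, HAVING, aggregate projections, or ORDER BY of its own.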
+ + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + """ + if ( + any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]) + or len(outer_scope.selected_sources) != 1 + or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions) + ): + return + + outer_scope.expression.set("order", inner_scope.expression.args.get("order")) diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index 2c9f89c..ab30d7a 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -22,18 +22,14 @@ def normalize(expression, dnf=False, max_distance=128): """ expression = simplify(expression) - expression = while_changing( - expression, lambda e: distributive_law(e, dnf, max_distance) - ) + expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance)) return simplify(expression) def normalized(expression, dnf=False): ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) - return not any( - connector.find_ancestor(ancestor) for connector in expression.find_all(root) - ) + return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root)) def normalization_distance(expression, dnf=False): @@ -54,9 +50,7 @@ def normalization_distance(expression, dnf=False): Returns: int: difference """ - return sum(_predicate_lengths(expression, dnf)) - ( - len(list(expression.find_all(exp.Connector))) + 1 - ) + return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1) def _predicate_lengths(expression, dnf): @@ -73,11 +67,7 @@ def _predicate_lengths(expression, dnf): left, right = expression.args.values() if isinstance(expression, exp.And if dnf else exp.Or): - x = [ - a + b - for a in _predicate_lengths(left, dnf) - for b in _predicate_lengths(right, dnf) - ] + x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)] return x return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) @@ -102,9 +92,7 @@ def distributive_law(expression, dnf, max_distance): to_func = exp.and_ if to_exp == exp.And else exp.or_ if isinstance(a, to_exp) and isinstance(b, to_exp): - if len(tuple(a.find_all(exp.Connector))) > len( - tuple(b.find_all(exp.Connector)) - ): + if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): return _distribute(a, b, from_func, to_func) return _distribute(b, a, from_func, to_func) if isinstance(a, to_exp): diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 40e4ab1..0c74e36 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -68,8 +68,4 @@ def normalize(expression): def other_table_names(join, exclude): - return [ - name - for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) - if name != exclude - ] + return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude] diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index c03fe3c..c8c2403 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -1,6 +1,7 @@ from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects from sqlglot.optimizer.isolate_table_selects import isolate_table_selects +from sqlglot.optimizer.merge_derived_tables import 
merge_derived_tables
 from sqlglot.optimizer.normalize import normalize
 from sqlglot.optimizer.optimize_joins import optimize_joins
 from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
@@ -10,8 +11,23 @@ from sqlglot.optimizer.qualify_tables import qualify_tables
 from sqlglot.optimizer.quote_identities import quote_identities
 from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
 
+RULES = (
+    qualify_tables,
+    isolate_table_selects,
+    qualify_columns,
+    pushdown_projections,
+    normalize,
+    unnest_subqueries,
+    expand_multi_table_selects,
+    pushdown_predicates,
+    optimize_joins,
+    eliminate_subqueries,
+    merge_derived_tables,
+    quote_identities,
+)
 
-def optimize(expression, schema=None, db=None, catalog=None):
+
+def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwargs):
     """
     Rewrite a sqlglot AST into an optimized form.
 
@@ -25,19 +41,18 @@
             3. {catalog: {db: {table: {col: type}}}}
         db (str): specify the default database, as might be set by a `USE DATABASE db` statement
         catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
+        rules (list): sequence of optimizer rules to use
+        **kwargs: if a rule has a keyword argument with the same name in **kwargs, it will be passed in
     Returns:
         sqlglot.Expression: optimized expression
     """
+    possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
     expression = expression.copy()
-    expression = qualify_tables(expression, db=db, catalog=catalog)
-    expression = isolate_table_selects(expression)
-    expression = qualify_columns(expression, schema)
-    expression = pushdown_projections(expression)
-    expression = normalize(expression)
-    expression = unnest_subqueries(expression)
-    expression = expand_multi_table_selects(expression)
-    expression = pushdown_predicates(expression)
-    expression = optimize_joins(expression)
-    expression = eliminate_subqueries(expression)
-    expression = quote_identities(expression)
+    for rule in rules:
+
+        # Find any additional rule parameters, beyond `expression`
+        rule_params = rule.__code__.co_varnames
+        rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}
+
+        expression = rule(expression, **rule_kwargs)
     return expression
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index e757322..a070d70 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -42,11 +42,7 @@ def pushdown(condition, sources):
     condition = condition.replace(simplify(condition))
     cnf_like = normalized(condition) or not normalized(condition, dnf=True)
 
-    predicates = list(
-        condition.flatten()
-        if isinstance(condition, exp.And if cnf_like else exp.Or)
-        else [condition]
-    )
+    predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])
 
     if cnf_like:
         pushdown_cnf(predicates, sources)
@@ -105,17 +101,11 @@ def pushdown_dnf(predicates, scope):
             for column in predicate.find_all(exp.Column):
                 if column.table == table:
                     condition = column.find_ancestor(exp.Condition)
-                    predicate_condition = (
-                        exp.and_(predicate_condition, condition)
-                        if predicate_condition
-                        else condition
-                    )
+                    predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition
 
             if predicate_condition:
                 conditions[table] = (
-                    exp.or_(conditions[table], predicate_condition)
-                    if table in conditions
-                    else predicate_condition
+ exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition ) for name, node in nodes.items(): @@ -133,9 +123,7 @@ def pushdown_dnf(predicates, scope): def nodes_for_predicate(predicate, sources): nodes = {} tables = exp.column_table_names(predicate) - where_condition = isinstance( - predicate.find_ancestor(exp.Join, exp.Where), exp.Where - ) + where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where) for table in tables: node, source = sources.get(table) or (None, None) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 394f49e..0bb947a 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -226,9 +226,7 @@ def _expand_stars(scope, resolver): tables = list(scope.selected_sources) _add_except_columns(expression, tables, except_columns) _add_replace_columns(expression, tables, replace_columns) - elif isinstance(expression, exp.Column) and isinstance( - expression.this, exp.Star - ): + elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star): tables = [expression.table] _add_except_columns(expression.this, tables, except_columns) _add_replace_columns(expression.this, tables, replace_columns) @@ -245,9 +243,7 @@ def _expand_stars(scope, resolver): if name not in except_columns.get(table_id, set()): alias_ = replace_columns.get(table_id, {}).get(name, name) column = exp.column(name, table) - new_selections.append( - alias(column, alias_) if alias_ != name else column - ) + new_selections.append(alias(column, alias_) if alias_ != name else column) scope.expression.set("expressions", new_selections) @@ -280,9 +276,7 @@ def _qualify_outputs(scope): """Ensure all output columns are aliased""" new_selections = [] - for i, (selection, aliased_column) in enumerate( - itertools.zip_longest(scope.selects, scope.outer_column_list) - ): + for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)): if isinstance(selection, exp.Column): # convoluted setter because a simple selection.replace(alias) would require a copy alias_ = alias(exp.column(""), alias=selection.name) @@ -302,11 +296,7 @@ def _qualify_outputs(scope): def _check_unknown_tables(scope): - if ( - scope.external_columns - and not scope.is_unnest - and not scope.is_correlated_subquery - ): + if scope.external_columns and not scope.is_unnest and not scope.is_correlated_subquery: raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}") @@ -334,20 +324,14 @@ class _Resolver: (str) table name """ if self._unambiguous_columns is None: - self._unambiguous_columns = self._get_unambiguous_columns( - self._get_all_source_columns() - ) + self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns()) return self._unambiguous_columns.get(column_name) @property def all_columns(self): """All available columns of all sources in this scope""" if self._all_columns is None: - self._all_columns = set( - column - for columns in self._get_all_source_columns().values() - for column in columns - ) + self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns) return self._all_columns def get_source_columns(self, name): @@ -369,9 +353,7 @@ class _Resolver: def _get_all_source_columns(self): if self._source_columns is None: - self._source_columns = { - k: self.get_source_columns(k) for k in self.scope.selected_sources - } + self._source_columns = 
{k: self.get_source_columns(k) for k in self.scope.selected_sources} return self._source_columns def _get_unambiguous_columns(self, source_columns): @@ -389,9 +371,7 @@ class _Resolver: source_columns = list(source_columns.items()) first_table, first_columns = source_columns[0] - unambiguous_columns = { - col: first_table for col in self._find_unique_columns(first_columns) - } + unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)} all_columns = set(unambiguous_columns) for table, columns in source_columns[1:]: diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 9f8b9f5..30e93ba 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -27,9 +27,7 @@ def qualify_tables(expression, db=None, catalog=None): for derived_table in scope.ctes + scope.derived_tables: if not derived_table.args.get("alias"): alias_ = f"_q_{next(sequence)}" - derived_table.set( - "alias", exp.TableAlias(this=exp.to_identifier(alias_)) - ) + derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_))) scope.rename_source(None, alias_) for source in scope.sources.values(): diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py index 9968108..1761228 100644 --- a/sqlglot/optimizer/schema.py +++ b/sqlglot/optimizer/schema.py @@ -57,9 +57,7 @@ class MappingSchema(Schema): for forbidden in self.forbidden_args: if table.text(forbidden): - raise ValueError( - f"Schema doesn't support {forbidden}. Received: {table.sql()}" - ) + raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") return list(_nested_get(self.schema, *zip(self.supported_table_args, args))) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index f6f59e8..e816e10 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -104,9 +104,7 @@ class Scope: elif isinstance(node, exp.CTE): self._ctes.append(node) prune = True - elif isinstance(node, exp.Subquery) and isinstance( - parent, (exp.From, exp.Join) - ): + elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): self._derived_tables.append(node) prune = True elif isinstance(node, exp.Subqueryable): @@ -195,20 +193,14 @@ class Scope: self._ensure_collected() columns = self._raw_columns - external_columns = [ - column - for scope in self.subquery_scopes - for column in scope.external_columns - ] + external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns] named_outputs = {e.alias_or_name for e in self.expression.expressions} self._columns = [ c for c in columns + external_columns - if not ( - c.find_ancestor(exp.Qualify, exp.Order) and c.name in named_outputs - ) + if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs) ] return self._columns @@ -229,9 +221,7 @@ class Scope: for table in self.tables: referenced_names.append( ( - table.parent.alias - if isinstance(table.parent, exp.Alias) - else table.name, + table.parent.alias if isinstance(table.parent, exp.Alias) else table.name, table, ) ) @@ -274,9 +264,7 @@ class Scope: sources in the current scope. 
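+        For example, in "SELECT * FROM x WHERE x.a = (SELECT y.b FROM y WHERE y.b = x.a)",
+        the column "x.a" inside the subquery is an external column of the inner scope.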
""" if self._external_columns is None: - self._external_columns = [ - c for c in self.columns if c.table not in self.selected_sources - ] + self._external_columns = [c for c in self.columns if c.table not in self.selected_sources] return self._external_columns def source_columns(self, source_name): @@ -310,6 +298,16 @@ class Scope: columns = self.sources.pop(old_name or "", []) self.sources[new_name] = columns + def add_source(self, name, source): + """Add a source to this scope""" + self.sources[name] = source + self.clear_cache() + + def remove_source(self, name): + """Remove a source from this scope""" + self.sources.pop(name, None) + self.clear_cache() + def traverse_scope(expression): """ @@ -334,7 +332,7 @@ def traverse_scope(expression): Args: expression (exp.Expression): expression to traverse Returns: - List[Scope]: scope instances + list[Scope]: scope instances """ return list(_traverse_scope(Scope(expression))) @@ -356,9 +354,7 @@ def _traverse_scope(scope): def _traverse_select(scope): yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE) yield from _traverse_subqueries(scope) - yield from _traverse_derived_tables( - scope.derived_tables, scope, ScopeType.DERIVED_TABLE - ) + yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE) _add_table_sources(scope) @@ -367,15 +363,11 @@ def _traverse_union(scope): # The last scope to be yield should be the top most scope left = None - for left in _traverse_scope( - scope.branch(scope.expression.left, scope_type=ScopeType.UNION) - ): + for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)): yield left right = None - for right in _traverse_scope( - scope.branch(scope.expression.right, scope_type=ScopeType.UNION) - ): + for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)): yield right scope.union = (left, right) @@ -387,14 +379,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type): for derived_table in derived_tables: for child_scope in _traverse_scope( scope.branch( - derived_table - if isinstance(derived_table, (exp.Unnest, exp.Lateral)) - else derived_table.this, + derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this, add_sources=sources if scope_type == ScopeType.CTE else None, outer_column_list=derived_table.alias_column_names, - scope_type=ScopeType.UNNEST - if isinstance(derived_table, exp.Unnest) - else scope_type, + scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type, ) ): yield child_scope @@ -430,9 +418,7 @@ def _add_table_sources(scope): def _traverse_subqueries(scope): for subquery in scope.subqueries: top = None - for child_scope in _traverse_scope( - scope.branch(subquery, scope_type=ScopeType.SUBQUERY) - ): + for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)): yield child_scope top = child_scope scope.subquery_scopes.append(top) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 6771153..319e6b6 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -188,9 +188,7 @@ def absorb_and_eliminate(expression): aa.replace(exp.TRUE if kind == exp.And else exp.FALSE) elif is_complement(b, ab): ab.replace(exp.TRUE if kind == exp.And else exp.FALSE) - elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set( - a.flatten() - ): + elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): 
a.replace(exp.FALSE if kind == exp.And else exp.TRUE) elif isinstance(b, kind): # eliminate @@ -227,9 +225,7 @@ def simplify_literals(expression): operands.append(a) if len(operands) < size: - return functools.reduce( - lambda a, b: expression.__class__(this=a, expression=b), operands - ) + return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands) elif isinstance(expression, exp.Neg): this = expression.this if this.is_number: diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 55c81c5..11c6eba 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -89,11 +89,7 @@ def decorrelate(select, parent_select, external_columns, sequence): return if isinstance(predicate, exp.Binary): - key = ( - predicate.right - if any(node is column for node, *_ in predicate.left.walk()) - else predicate.left - ) + key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left else: return @@ -124,9 +120,7 @@ def decorrelate(select, parent_select, external_columns, sequence): # if the value of the subquery is not an agg or a key, we need to collect it into an array # so that it can be grouped if not value.find(exp.AggFunc) and value.this not in group_by: - select.select( - f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False - ) + select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False) # exists queries should not have any selects as it only checks if there are any rows # all selects will be added by the optimizer and only used for join keys @@ -151,16 +145,12 @@ def decorrelate(select, parent_select, external_columns, sequence): else: parent_predicate = _replace(parent_predicate, "TRUE") elif isinstance(parent_predicate, exp.All): - parent_predicate = _replace( - parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})" - ) + parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})") elif isinstance(parent_predicate, exp.Any): if value.this in group_by: parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}") else: - parent_predicate = _replace( - parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})" - ) + parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})") elif isinstance(parent_predicate, exp.In): if value.this in group_by: parent_predicate = _replace(parent_predicate, f"{other} = {alias}") @@ -178,9 +168,7 @@ def decorrelate(select, parent_select, external_columns, sequence): if key in group_by: key.replace(nested) - parent_predicate = _replace( - parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)" - ) + parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)") elif isinstance(predicate, exp.EQ): parent_predicate = _replace( parent_predicate, diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 9396c50..f46bafe 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -78,6 +78,7 @@ class Parser: TokenType.TEXT, TokenType.BINARY, TokenType.JSON, + TokenType.INTERVAL, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, TokenType.DATETIME, @@ -85,6 +86,12 @@ class Parser: TokenType.DECIMAL, TokenType.UUID, TokenType.GEOGRAPHY, + TokenType.GEOMETRY, + TokenType.HLLSKETCH, + TokenType.SUPER, + TokenType.SERIAL, + TokenType.SMALLSERIAL, + TokenType.BIGSERIAL, *NESTED_TYPE_TOKENS, } @@ -100,13 +107,14 @@ class Parser: ID_VAR_TOKENS = { TokenType.VAR, 
TokenType.ALTER, + TokenType.ALWAYS, TokenType.BEGIN, + TokenType.BOTH, TokenType.BUCKET, TokenType.CACHE, TokenType.COLLATE, TokenType.COMMIT, TokenType.CONSTRAINT, - TokenType.CONVERT, TokenType.DEFAULT, TokenType.DELETE, TokenType.ENGINE, @@ -115,14 +123,19 @@ class Parser: TokenType.FALSE, TokenType.FIRST, TokenType.FOLLOWING, + TokenType.FOR, TokenType.FORMAT, TokenType.FUNCTION, + TokenType.GENERATED, + TokenType.IDENTITY, TokenType.IF, TokenType.INDEX, TokenType.ISNULL, TokenType.INTERVAL, TokenType.LAZY, + TokenType.LEADING, TokenType.LOCATION, + TokenType.NATURAL, TokenType.NEXT, TokenType.ONLY, TokenType.OPTIMIZE, @@ -141,6 +154,7 @@ class Parser: TokenType.TABLE_FORMAT, TokenType.TEMPORARY, TokenType.TOP, + TokenType.TRAILING, TokenType.TRUNCATE, TokenType.TRUE, TokenType.UNBOUNDED, @@ -150,18 +164,15 @@ class Parser: *TYPE_TOKENS, } - CASTS = { - TokenType.CAST, - TokenType.TRY_CAST, - } + TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL} + + TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH} FUNC_TOKENS = { - TokenType.CONVERT, TokenType.CURRENT_DATE, TokenType.CURRENT_DATETIME, TokenType.CURRENT_TIMESTAMP, TokenType.CURRENT_TIME, - TokenType.EXTRACT, TokenType.FILTER, TokenType.FIRST, TokenType.FORMAT, @@ -178,7 +189,6 @@ class Parser: TokenType.DATETIME, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, - *CASTS, *NESTED_TYPE_TOKENS, *SUBQUERY_PREDICATES, } @@ -215,6 +225,7 @@ class Parser: FACTOR = { TokenType.DIV: exp.IntDiv, + TokenType.LR_ARROW: exp.Distance, TokenType.SLASH: exp.Div, TokenType.STAR: exp.Mul, } @@ -299,14 +310,13 @@ class Parser: PRIMARY_PARSERS = { TokenType.STRING: lambda _, token: exp.Literal.string(token.text), TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text), - TokenType.STAR: lambda self, _: exp.Star( - **{"except": self._parse_except(), "replace": self._parse_replace()} - ), + TokenType.STAR: lambda self, _: exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}), TokenType.NULL: lambda *_: exp.Null(), TokenType.TRUE: lambda *_: exp.Boolean(this=True), TokenType.FALSE: lambda *_: exp.Boolean(this=False), TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(), TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text), + TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text), TokenType.INTRODUCER: lambda self, token: self.expression( exp.Introducer, this=token.text, @@ -319,13 +329,16 @@ class Parser: TokenType.IN: lambda self, this: self._parse_in(this), TokenType.IS: lambda self, this: self._parse_is(this), TokenType.LIKE: lambda self, this: self._parse_escape( - self.expression(exp.Like, this=this, expression=self._parse_type()) + self.expression(exp.Like, this=this, expression=self._parse_bitwise()) ), TokenType.ILIKE: lambda self, this: self._parse_escape( - self.expression(exp.ILike, this=this, expression=self._parse_type()) + self.expression(exp.ILike, this=this, expression=self._parse_bitwise()) ), TokenType.RLIKE: lambda self, this: self.expression( - exp.RegexpLike, this=this, expression=self._parse_type() + exp.RegexpLike, this=this, expression=self._parse_bitwise() + ), + TokenType.SIMILAR_TO: lambda self, this: self.expression( + exp.SimilarTo, this=this, expression=self._parse_bitwise() ), } @@ -363,28 +376,21 @@ class Parser: } FUNCTION_PARSERS = { - TokenType.CONVERT: lambda self, _: self._parse_convert(), - TokenType.EXTRACT: lambda self, _: self._parse_extract(), - **{ - token_type: lambda self, token_type: self._parse_cast( - self.STRICT_CAST and 
token_type == TokenType.CAST - ) - for token_type in CASTS - }, + "CONVERT": lambda self: self._parse_convert(), + "EXTRACT": lambda self: self._parse_extract(), + "SUBSTRING": lambda self: self._parse_substring(), + "TRIM": lambda self: self._parse_trim(), + "CAST": lambda self: self._parse_cast(self.STRICT_CAST), + "TRY_CAST": lambda self: self._parse_cast(False), } QUERY_MODIFIER_PARSERS = { - "laterals": lambda self: self._parse_laterals(), - "joins": lambda self: self._parse_joins(), "where": lambda self: self._parse_where(), "group": lambda self: self._parse_group(), "having": lambda self: self._parse_having(), "qualify": lambda self: self._parse_qualify(), - "window": lambda self: self._match(TokenType.WINDOW) - and self._parse_window(self._parse_id_var(), alias=True), - "distribute": lambda self: self._parse_sort( - TokenType.DISTRIBUTE_BY, exp.Distribute - ), + "window": lambda self: self._match(TokenType.WINDOW) and self._parse_window(self._parse_id_var(), alias=True), + "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute), "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort), "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster), "order": lambda self: self._parse_order(), @@ -392,6 +398,8 @@ class Parser: "offset": lambda self: self._parse_offset(), } + MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) + CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX} STRICT_CAST = True @@ -457,9 +465,7 @@ class Parser: Returns the list of syntax trees (:class:`~sqlglot.expressions.Expression`). """ - return self._parse( - parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql - ) + return self._parse(parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql) def parse_into(self, expression_types, raw_tokens, sql=None): for expression_type in ensure_list(expression_types): @@ -532,21 +538,13 @@ class Parser: for k in expression.args: if k not in expression.arg_types: - self.raise_error( - f"Unexpected keyword: '{k}' for {expression.__class__}" - ) + self.raise_error(f"Unexpected keyword: '{k}' for {expression.__class__}") for k, mandatory in expression.arg_types.items(): v = expression.args.get(k) if mandatory and (v is None or (isinstance(v, list) and not v)): - self.raise_error( - f"Required keyword: '{k}' missing for {expression.__class__}" - ) + self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}") - if ( - args - and len(args) > len(expression.arg_types) - and not expression.is_var_len_args - ): + if args and len(args) > len(expression.arg_types) and not expression.is_var_len_args: self.raise_error( f"The number of provided arguments ({len(args)}) is greater than " f"the maximum number of supported arguments ({len(expression.arg_types)})" @@ -594,11 +592,7 @@ class Parser: ) expression = self._parse_expression() - expression = ( - self._parse_set_operations(expression) - if expression - else self._parse_select() - ) + expression = self._parse_set_operations(expression) if expression else self._parse_select() self._parse_query_modifiers(expression) return expression @@ -618,11 +612,7 @@ class Parser: ) def _parse_exists(self, not_=False): - return ( - self._match(TokenType.IF) - and (not not_ or self._match(TokenType.NOT)) - and self._match(TokenType.EXISTS) - ) + return self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS) def _parse_create(self): replace = 
self._match(TokenType.OR) and self._match(TokenType.REPLACE) @@ -647,11 +637,9 @@ class Parser: this = self._parse_index() elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW): this = self._parse_table(schema=True) - properties = self._parse_properties( - this if isinstance(this, exp.Schema) else None - ) + properties = self._parse_properties(this if isinstance(this, exp.Schema) else None) if self._match(TokenType.ALIAS): - expression = self._parse_select() + expression = self._parse_select(nested=True) return self.expression( exp.Create, @@ -682,9 +670,7 @@ class Parser: if schema and not isinstance(value, exp.Schema): columns = {v.name.upper() for v in value.expressions} partitions = [ - expression - for expression in schema.expressions - if expression.this.name.upper() in columns + expression for expression in schema.expressions if expression.this.name.upper() in columns ] schema.set( "expressions", @@ -811,7 +797,7 @@ class Parser: this=self._parse_table(schema=True), exists=self._parse_exists(), partition=self._parse_partition(), - expression=self._parse_select(), + expression=self._parse_select(nested=True), overwrite=overwrite, ) @@ -829,8 +815,7 @@ class Parser: exp.Update, **{ "this": self._parse_table(schema=True), - "expressions": self._match(TokenType.SET) - and self._parse_csv(self._parse_equality), + "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality), "from": self._parse_from(), "where": self._parse_where(), }, @@ -865,7 +850,7 @@ class Parser: this=table, lazy=lazy, options=options, - expression=self._parse_select(), + expression=self._parse_select(nested=True), ) def _parse_partition(self): @@ -894,9 +879,7 @@ class Parser: self._match_r_paren() return self.expression(exp.Tuple, expressions=expressions) - def _parse_select(self, table=None): - index = self._index - + def _parse_select(self, nested=False, table=False): if self._match(TokenType.SELECT): hint = self._parse_hint() all_ = self._match(TokenType.ALL) @@ -912,9 +895,7 @@ class Parser: self.raise_error("Cannot specify both ALL and DISTINCT after SELECT") limit = self._parse_limit(top=True) - expressions = self._parse_csv( - lambda: self._parse_annotation(self._parse_expression()) - ) + expressions = self._parse_csv(lambda: self._parse_annotation(self._parse_expression())) this = self.expression( exp.Select, @@ -960,19 +941,13 @@ class Parser: ) else: self.raise_error(f"{this.key} does not support CTE") - elif self._match(TokenType.L_PAREN): - this = self._parse_table() if table else self._parse_select() - - if this: - self._parse_query_modifiers(this) - self._match_r_paren() - this = self._parse_subquery(this) - else: - self._retreat(index) + elif (table or nested) and self._match(TokenType.L_PAREN): + this = self._parse_table() if table else self._parse_select(nested=True) + self._parse_query_modifiers(this) + self._match_r_paren() + this = self._parse_subquery(this) elif self._match(TokenType.VALUES): - this = self.expression( - exp.Values, expressions=self._parse_csv(self._parse_value) - ) + this = self.expression(exp.Values, expressions=self._parse_csv(self._parse_value)) alias = self._parse_table_alias() if alias: this = self.expression(exp.Subquery, this=this, alias=alias) @@ -1001,7 +976,7 @@ class Parser: def _parse_table_alias(self): any_token = self._match(TokenType.ALIAS) - alias = self._parse_id_var(any_token) + alias = self._parse_id_var(any_token=any_token, tokens=self.TABLE_ALIAS_TOKENS) columns = None if self._match(TokenType.L_PAREN): @@ -1021,9 +996,24 
@@ class Parser: return self.expression(exp.Subquery, this=this, alias=self._parse_table_alias()) def _parse_query_modifiers(self, this): - if not isinstance(this, (exp.Subquery, exp.Subqueryable)): + if not isinstance(this, self.MODIFIABLES): return + table = isinstance(this, exp.Table) + + while True: + lateral = self._parse_lateral() + join = self._parse_join() + comma = None if table else self._match(TokenType.COMMA) + if lateral: + this.append("laterals", lateral) + if join: + this.append("joins", join) + if comma: + this.args["from"].append("expressions", self._parse_table()) + if not (lateral or join or comma): + break + for key, parser in self.QUERY_MODIFIER_PARSERS.items(): expression = parser(self) @@ -1032,9 +1022,7 @@ class Parser: def _parse_annotation(self, expression): if self._match(TokenType.ANNOTATION): - return self.expression( - exp.Annotation, this=self._prev.text, expression=expression - ) + return self.expression(exp.Annotation, this=self._prev.text, expression=expression) return expression @@ -1052,16 +1040,16 @@ class Parser: return self.expression(exp.From, expressions=self._parse_csv(self._parse_table)) - def _parse_laterals(self): - return self._parse_all(self._parse_lateral) - def _parse_lateral(self): if not self._match(TokenType.LATERAL): return None - if not self._match(TokenType.VIEW): - self.raise_error("Expected VIEW after LATERAL") + subquery = self._parse_select(table=True) + if subquery: + return self.expression(exp.Lateral, this=subquery) + + self._match(TokenType.VIEW) outer = self._match(TokenType.OUTER) return self.expression( @@ -1071,31 +1059,27 @@ class Parser: alias=self.expression( exp.TableAlias, this=self._parse_id_var(any_token=False), - columns=( - self._parse_csv(self._parse_id_var) - if self._match(TokenType.ALIAS) - else None - ), + columns=(self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else None), ), ) - def _parse_joins(self): - return self._parse_all(self._parse_join) - def _parse_join_side_and_kind(self): return ( + self._match(TokenType.NATURAL) and self._prev, self._match_set(self.JOIN_SIDES) and self._prev, self._match_set(self.JOIN_KINDS) and self._prev, ) def _parse_join(self): - side, kind = self._parse_join_side_and_kind() + natural, side, kind = self._parse_join_side_and_kind() if not self._match(TokenType.JOIN): return None kwargs = {"this": self._parse_table()} + if natural: + kwargs["natural"] = True if side: kwargs["side"] = side.text if kind: @@ -1120,6 +1104,11 @@ class Parser: ) def _parse_table(self, schema=False): + lateral = self._parse_lateral() + + if lateral: + return lateral + unnest = self._parse_unnest() if unnest: @@ -1172,9 +1161,7 @@ class Parser: expressions = self._parse_csv(self._parse_column) self._match_r_paren() - ordinality = bool( - self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY) - ) + ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY)) alias = self._parse_table_alias() @@ -1280,17 +1267,13 @@ class Parser: if not self._match(TokenType.ORDER_BY): return this - return self.expression( - exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered) - ) + return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)) def _parse_sort(self, token_type, exp_class): if not self._match(token_type): return None - return self.expression( - exp_class, expressions=self._parse_csv(self._parse_ordered) - ) + return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered)) def 
_parse_ordered(self): this = self._parse_conjunction() @@ -1305,22 +1288,17 @@ class Parser: if ( not explicitly_null_ordered and ( - (asc and self.null_ordering == "nulls_are_small") - or (desc and self.null_ordering != "nulls_are_small") + (asc and self.null_ordering == "nulls_are_small") or (desc and self.null_ordering != "nulls_are_small") ) and self.null_ordering != "nulls_are_last" ): nulls_first = True - return self.expression( - exp.Ordered, this=this, desc=desc, nulls_first=nulls_first - ) + return self.expression(exp.Ordered, this=this, desc=desc, nulls_first=nulls_first) def _parse_limit(self, this=None, top=False): if self._match(TokenType.TOP if top else TokenType.LIMIT): - return self.expression( - exp.Limit, this=this, expression=self._parse_number() - ) + return self.expression(exp.Limit, this=this, expression=self._parse_number()) if self._match(TokenType.FETCH): direction = self._match_set((TokenType.FIRST, TokenType.NEXT)) direction = self._prev.text if direction else "FIRST" @@ -1354,7 +1332,7 @@ class Parser: expression, this=this, distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL), - expression=self._parse_select(), + expression=self._parse_select(nested=True), ) def _parse_expression(self): @@ -1396,9 +1374,7 @@ class Parser: this = self.expression(exp.In, this=this, unnest=unnest) else: self._match_l_paren() - expressions = self._parse_csv( - lambda: self._parse_select() or self._parse_expression() - ) + expressions = self._parse_csv(lambda: self._parse_select() or self._parse_expression()) if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable): this = self.expression(exp.In, this=this, query=expressions[0]) @@ -1430,13 +1406,9 @@ class Parser: expression=self._parse_term(), ) elif self._match_pair(TokenType.LT, TokenType.LT): - this = self.expression( - exp.BitwiseLeftShift, this=this, expression=self._parse_term() - ) + this = self.expression(exp.BitwiseLeftShift, this=this, expression=self._parse_term()) elif self._match_pair(TokenType.GT, TokenType.GT): - this = self.expression( - exp.BitwiseRightShift, this=this, expression=self._parse_term() - ) + this = self.expression(exp.BitwiseRightShift, this=this, expression=self._parse_term()) else: break @@ -1524,7 +1496,7 @@ class Parser: self.raise_error("Expecting >") if type_token in self.TIMESTAMPS: - tz = self._match(TokenType.WITH_TIME_ZONE) + tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ self._match(TokenType.WITHOUT_TIME_ZONE) if tz: return exp.DataType( @@ -1594,16 +1566,14 @@ class Parser: if query: expressions = [query] else: - expressions = self._parse_csv( - lambda: self._parse_alias(self._parse_conjunction(), explicit=True) - ) + expressions = self._parse_csv(lambda: self._parse_alias(self._parse_conjunction(), explicit=True)) this = list_get(expressions, 0) self._parse_query_modifiers(this) self._match_r_paren() if isinstance(this, exp.Subqueryable): - return self._parse_subquery(this) + return self._parse_set_operations(self._parse_subquery(this)) if len(expressions) > 1: return self.expression(exp.Tuple, expressions=expressions) return self.expression(exp.Paren, this=this) @@ -1611,11 +1581,7 @@ class Parser: return None def _parse_field(self, any_token=False): - return ( - self._parse_primary() - or self._parse_function() - or self._parse_id_var(any_token) - ) + return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token) def _parse_function(self): if not self._curr: @@ -1628,21 +1594,22 @@ class 
Parser: if not self._next or self._next.token_type != TokenType.L_PAREN: if token_type in self.NO_PAREN_FUNCTIONS: - return self.expression( - self._advance() or self.NO_PAREN_FUNCTIONS[token_type] - ) + return self.expression(self._advance() or self.NO_PAREN_FUNCTIONS[token_type]) return None if token_type not in self.FUNC_TOKENS: return None - if self._match_set(self.FUNCTION_PARSERS): - self._advance() - this = self.FUNCTION_PARSERS[token_type](self, token_type) + this = self._curr.text + upper = this.upper() + self._advance(2) + + parser = self.FUNCTION_PARSERS.get(upper) + + if parser: + this = parser(self) else: subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type) - this = self._curr.text - self._advance(2) if subquery_predicate and self._curr.token_type in ( TokenType.SELECT, @@ -1652,7 +1619,7 @@ class Parser: self._match_r_paren() return this - function = self.FUNCTIONS.get(this.upper()) + function = self.FUNCTIONS.get(upper) args = self._parse_csv(self._parse_lambda) if function: @@ -1700,10 +1667,7 @@ class Parser: self._retreat(index) return this - args = self._parse_csv( - lambda: self._parse_constraint() - or self._parse_column_def(self._parse_field()) - ) + args = self._parse_csv(lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True))) self._match_r_paren() return self.expression(exp.Schema, this=this, expressions=args) @@ -1720,12 +1684,9 @@ class Parser: break constraints.append(constraint) - return self.expression( - exp.ColumnDef, this=this, kind=kind, constraints=constraints - ) + return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints) def _parse_column_constraint(self): - kind = None this = None if self._match(TokenType.CONSTRAINT): @@ -1735,28 +1696,28 @@ class Parser: kind = exp.AutoIncrementColumnConstraint() elif self._match(TokenType.CHECK): self._match_l_paren() - kind = self.expression( - exp.CheckColumnConstraint, this=self._parse_conjunction() - ) + kind = self.expression(exp.CheckColumnConstraint, this=self._parse_conjunction()) self._match_r_paren() elif self._match(TokenType.COLLATE): kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var()) elif self._match(TokenType.DEFAULT): - kind = self.expression( - exp.DefaultColumnConstraint, this=self._parse_field() - ) - elif self._match(TokenType.NOT) and self._match(TokenType.NULL): + kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_field()) + elif self._match_pair(TokenType.NOT, TokenType.NULL): kind = exp.NotNullColumnConstraint() elif self._match(TokenType.SCHEMA_COMMENT): - kind = self.expression( - exp.CommentColumnConstraint, this=self._parse_string() - ) + kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string()) elif self._match(TokenType.PRIMARY_KEY): kind = exp.PrimaryKeyColumnConstraint() elif self._match(TokenType.UNIQUE): kind = exp.UniqueColumnConstraint() - - if kind is None: + elif self._match(TokenType.GENERATED): + if self._match(TokenType.BY_DEFAULT): + kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False) + else: + self._match(TokenType.ALWAYS) + kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True) + self._match_pair(TokenType.ALIAS, TokenType.IDENTITY) + else: return None return self.expression(exp.ColumnConstraint, this=this, kind=kind) @@ -1864,9 +1825,7 @@ class Parser: if not self._match(TokenType.END): self.raise_error("Expected END after CASE", self._prev) - return self._parse_window( - self.expression(exp.Case, 
this=expression, ifs=ifs, default=default) - ) + return self._parse_window(self.expression(exp.Case, this=expression, ifs=ifs, default=default)) def _parse_if(self): if self._match(TokenType.L_PAREN): @@ -1889,7 +1848,7 @@ class Parser: if not self._match(TokenType.FROM): self.raise_error("Expected FROM after EXTRACT", self._prev) - return self.expression(exp.Extract, this=this, expression=self._parse_type()) + return self.expression(exp.Extract, this=this, expression=self._parse_bitwise()) def _parse_cast(self, strict): this = self._parse_conjunction() @@ -1917,12 +1876,54 @@ class Parser: to = None return self.expression(exp.Cast, this=this, to=to) + def _parse_substring(self): + # Postgres supports the form: substring(string [from int] [for int]) + # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 + + args = self._parse_csv(self._parse_bitwise) + + if self._match(TokenType.FROM): + args.append(self._parse_bitwise()) + if self._match(TokenType.FOR): + args.append(self._parse_bitwise()) + + this = exp.Substring.from_arg_list(args) + self.validate_expression(this, args) + + return this + + def _parse_trim(self): + # https://www.w3resource.com/sql/character-functions/trim.php + # https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html + + position = None + collation = None + + if self._match_set(self.TRIM_TYPES): + position = self._prev.text.upper() + + expression = self._parse_term() + if self._match(TokenType.FROM): + this = self._parse_term() + else: + this = expression + expression = None + + if self._match(TokenType.COLLATE): + collation = self._parse_term() + + return self.expression( + exp.Trim, + this=this, + position=position, + expression=expression, + collation=collation, + ) + def _parse_window(self, this, alias=False): if self._match(TokenType.FILTER): self._match_l_paren() - this = self.expression( - exp.Filter, this=this, expression=self._parse_where() - ) + this = self.expression(exp.Filter, this=this, expression=self._parse_where()) self._match_r_paren() if self._match(TokenType.WITHIN_GROUP): @@ -1935,6 +1936,25 @@ class Parser: self._match_r_paren() return this + # SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER + # Some dialects choose to implement and some do not. + # https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html + + # There is some code above in _parse_lambda that handles + # SELECT FIRST_VALUE(TABLE.COLUMN IGNORE|RESPECT NULLS) OVER ... + + # The below changes handle + # SELECT FIRST_VALUE(TABLE.COLUMN) IGNORE|RESPECT NULLS OVER ... + + # Oracle allows both formats + # (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html) + # and Snowflake chose to do the same for familiarity + # https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes + if self._match(TokenType.IGNORE_NULLS): + this = self.expression(exp.IgnoreNulls, this=this) + elif self._match(TokenType.RESPECT_NULLS): + this = self.expression(exp.RespectNulls, this=this) + # bigquery select from window x AS (partition by ...) 
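+        # standard-SQL named windows take the same path, e.g.
+        # SELECT SUM(a) OVER w FROM t WINDOW w AS (PARTITION BY b)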
if alias: self._match(TokenType.ALIAS) @@ -1992,13 +2012,9 @@ class Parser: self._match(TokenType.BETWEEN) return { - "value": ( - self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) - and self._prev.text - ) + "value": (self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text) or self._parse_bitwise(), - "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) - and self._prev.text, + "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text, } def _parse_alias(self, this, explicit=False): @@ -2023,22 +2039,16 @@ class Parser: return this - def _parse_id_var(self, any_token=True): + def _parse_id_var(self, any_token=True, tokens=None): identifier = self._parse_identifier() if identifier: return identifier - if ( - any_token - and self._curr - and self._curr.token_type not in self.RESERVED_KEYWORDS - ): + if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: return self._advance() or exp.Identifier(this=self._prev.text, quoted=False) - return self._match_set(self.ID_VAR_TOKENS) and exp.Identifier( - this=self._prev.text, quoted=False - ) + return self._match_set(tokens or self.ID_VAR_TOKENS) and exp.Identifier(this=self._prev.text, quoted=False) def _parse_string(self): if self._match(TokenType.STRING): @@ -2077,9 +2087,7 @@ class Parser: def _parse_star(self): if self._match(TokenType.STAR): - return exp.Star( - **{"except": self._parse_except(), "replace": self._parse_replace()} - ) + return exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}) return None def _parse_placeholder(self): @@ -2117,15 +2125,10 @@ class Parser: this = parse() while self._match_set(expressions): - this = self.expression( - expressions[self._prev.token_type], this=this, expression=parse() - ) + this = self.expression(expressions[self._prev.token_type], this=this, expression=parse()) return this - def _parse_all(self, parse): - return list(iter(parse, None)) - def _parse_wrapped_id_vars(self): self._match_l_paren() expressions = self._parse_csv(self._parse_id_var) @@ -2156,10 +2159,7 @@ class Parser: if not self._curr or not self._next: return None - if ( - self._curr.token_type == token_type_a - and self._next.token_type == token_type_b - ): + if self._curr.token_type == token_type_a and self._next.token_type == token_type_b: if advance: self._advance(2) return True diff --git a/sqlglot/planner.py b/sqlglot/planner.py index 2006a75..ed0b66c 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -72,9 +72,7 @@ class Step: if from_: from_ = from_.expressions if len(from_) > 1: - raise UnsupportedError( - "Multi-from statements are unsupported. Run it through the optimizer" - ) + raise UnsupportedError("Multi-from statements are unsupported. 
Run it through the optimizer") step = Scan.from_expression(from_[0], ctes) else: @@ -104,9 +102,7 @@ class Step: continue if operand not in operands: operands[operand] = f"_a_{next(sequence)}" - operand.replace( - exp.column(operands[operand], step.name, quoted=True) - ) + operand.replace(exp.column(operands[operand], step.name, quoted=True)) else: projections.append(e) @@ -121,14 +117,9 @@ class Step: aggregate = Aggregate() aggregate.source = step.name aggregate.name = step.name - aggregate.operands = tuple( - alias(operand, alias_) for operand, alias_ in operands.items() - ) + aggregate.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items()) aggregate.aggregations = aggregations - aggregate.group = [ - exp.column(e.alias_or_name, step.name, quoted=True) - for e in group.expressions - ] + aggregate.group = [exp.column(e.alias_or_name, step.name, quoted=True) for e in group.expressions] aggregate.add_dependency(step) step = aggregate @@ -212,9 +203,7 @@ class Scan(Step): alias_ = expression.alias if not alias_: - raise UnsupportedError( - "Tables/Subqueries must be aliased. Run it through the optimizer" - ) + raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer") if isinstance(expression, exp.Subquery): step = Step.from_expression(table, ctes) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index e4b754d..bd95bc7 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -38,6 +38,7 @@ class TokenType(AutoName): DARROW = auto() HASH_ARROW = auto() DHASH_ARROW = auto() + LR_ARROW = auto() ANNOTATION = auto() DOLLAR = auto() @@ -53,6 +54,7 @@ class TokenType(AutoName): TABLE = auto() VAR = auto() BIT_STRING = auto() + HEX_STRING = auto() # types BOOLEAN = auto() @@ -78,10 +80,17 @@ class TokenType(AutoName): UUID = auto() GEOGRAPHY = auto() NULLABLE = auto() + GEOMETRY = auto() + HLLSKETCH = auto() + SUPER = auto() + SERIAL = auto() + SMALLSERIAL = auto() + BIGSERIAL = auto() # keywords ADD_FILE = auto() ALIAS = auto() + ALWAYS = auto() ALL = auto() ALTER = auto() ANALYZE = auto() @@ -92,11 +101,12 @@ class TokenType(AutoName): AUTO_INCREMENT = auto() BEGIN = auto() BETWEEN = auto() + BOTH = auto() BUCKET = auto() + BY_DEFAULT = auto() CACHE = auto() CALL = auto() CASE = auto() - CAST = auto() CHARACTER_SET = auto() CHECK = auto() CLUSTER_BY = auto() @@ -104,7 +114,6 @@ class TokenType(AutoName): COMMENT = auto() COMMIT = auto() CONSTRAINT = auto() - CONVERT = auto() CREATE = auto() CROSS = auto() CUBE = auto() @@ -127,22 +136,24 @@ class TokenType(AutoName): EXCEPT = auto() EXISTS = auto() EXPLAIN = auto() - EXTRACT = auto() FALSE = auto() FETCH = auto() FILTER = auto() FINAL = auto() FIRST = auto() FOLLOWING = auto() + FOR = auto() FOREIGN_KEY = auto() FORMAT = auto() FULL = auto() FUNCTION = auto() FROM = auto() + GENERATED = auto() GROUP_BY = auto() GROUPING_SETS = auto() HAVING = auto() HINT = auto() + IDENTITY = auto() IF = auto() IGNORE_NULLS = auto() ILIKE = auto() @@ -159,12 +170,14 @@ class TokenType(AutoName): JOIN = auto() LATERAL = auto() LAZY = auto() + LEADING = auto() LEFT = auto() LIKE = auto() LIMIT = auto() LOCATION = auto() MAP = auto() MOD = auto() + NATURAL = auto() NEXT = auto() NO_ACTION = auto() NULL = auto() @@ -204,8 +217,10 @@ class TokenType(AutoName): ROWS = auto() SCHEMA_COMMENT = auto() SELECT = auto() + SEPARATOR = auto() SET = auto() SHOW = auto() + SIMILAR_TO = auto() SOME = auto() SORT_BY = auto() STORED = auto() @@ -213,12 +228,11 @@ class TokenType(AutoName): TABLE_FORMAT = auto() 
TABLE_SAMPLE = auto() TEMPORARY = auto() - TIME = auto() TOP = auto() THEN = auto() TRUE = auto() + TRAILING = auto() TRUNCATE = auto() - TRY_CAST = auto() UNBOUNDED = auto() UNCACHE = auto() UNION = auto() @@ -272,35 +286,32 @@ class _Tokenizer(type): def __new__(cls, clsname, bases, attrs): klass = super().__new__(cls, clsname, bases, attrs) - klass.QUOTES = dict( - (quote, quote) if isinstance(quote, str) else (quote[0], quote[1]) - for quote in klass.QUOTES - ) - - klass.IDENTIFIERS = dict( - (identifier, identifier) - if isinstance(identifier, str) - else (identifier[0], identifier[1]) - for identifier in klass.IDENTIFIERS - ) - - klass.COMMENTS = dict( - (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) - for comment in klass.COMMENTS + klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES) + klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS) + klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS) + klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS) + klass._COMMENTS = dict( + (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS ) klass.KEYWORD_TRIE = new_trie( key.upper() for key, value in { **klass.KEYWORDS, - **{comment: TokenType.COMMENT for comment in klass.COMMENTS}, - **{quote: TokenType.QUOTE for quote in klass.QUOTES}, + **{comment: TokenType.COMMENT for comment in klass._COMMENTS}, + **{quote: TokenType.QUOTE for quote in klass._QUOTES}, + **{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS}, + **{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS}, }.items() if " " in key or any(single in key for single in klass.SINGLE_TOKENS) ) return klass + @staticmethod + def _delimeter_list_to_dict(list): + return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list) + class Tokenizer(metaclass=_Tokenizer): SINGLE_TOKENS = { @@ -339,6 +350,10 @@ class Tokenizer(metaclass=_Tokenizer): QUOTES = ["'"] + BIT_STRINGS = [] + + HEX_STRINGS = [] + IDENTIFIERS = ['"'] ESCAPE = "'" @@ -357,6 +372,7 @@ class Tokenizer(metaclass=_Tokenizer): "->>": TokenType.DARROW, "#>": TokenType.HASH_ARROW, "#>>": TokenType.DHASH_ARROW, + "<->": TokenType.LR_ARROW, "ADD ARCHIVE": TokenType.ADD_FILE, "ADD ARCHIVES": TokenType.ADD_FILE, "ADD FILE": TokenType.ADD_FILE, @@ -374,12 +390,12 @@ class Tokenizer(metaclass=_Tokenizer): "AUTO_INCREMENT": TokenType.AUTO_INCREMENT, "BEGIN": TokenType.BEGIN, "BETWEEN": TokenType.BETWEEN, + "BOTH": TokenType.BOTH, "BUCKET": TokenType.BUCKET, "CALL": TokenType.CALL, "CACHE": TokenType.CACHE, "UNCACHE": TokenType.UNCACHE, "CASE": TokenType.CASE, - "CAST": TokenType.CAST, "CHARACTER SET": TokenType.CHARACTER_SET, "CHECK": TokenType.CHECK, "CLUSTER BY": TokenType.CLUSTER_BY, @@ -387,7 +403,6 @@ class Tokenizer(metaclass=_Tokenizer): "COMMENT": TokenType.SCHEMA_COMMENT, "COMMIT": TokenType.COMMIT, "CONSTRAINT": TokenType.CONSTRAINT, - "CONVERT": TokenType.CONVERT, "CREATE": TokenType.CREATE, "CROSS": TokenType.CROSS, "CUBE": TokenType.CUBE, @@ -408,7 +423,6 @@ class Tokenizer(metaclass=_Tokenizer): "EXCEPT": TokenType.EXCEPT, "EXISTS": TokenType.EXISTS, "EXPLAIN": TokenType.EXPLAIN, - "EXTRACT": TokenType.EXTRACT, "FALSE": TokenType.FALSE, "FETCH": TokenType.FETCH, "FILTER": TokenType.FILTER, @@ -437,10 +451,12 @@ class Tokenizer(metaclass=_Tokenizer): "JOIN": TokenType.JOIN, "LATERAL": TokenType.LATERAL, "LAZY": TokenType.LAZY, + "LEADING": TokenType.LEADING, "LEFT": TokenType.LEFT, 
"LIKE": TokenType.LIKE, "LIMIT": TokenType.LIMIT, "LOCATION": TokenType.LOCATION, + "NATURAL": TokenType.NATURAL, "NEXT": TokenType.NEXT, "NO ACTION": TokenType.NO_ACTION, "NOT": TokenType.NOT, @@ -490,8 +506,8 @@ class Tokenizer(metaclass=_Tokenizer): "TEMPORARY": TokenType.TEMPORARY, "THEN": TokenType.THEN, "TRUE": TokenType.TRUE, + "TRAILING": TokenType.TRAILING, "TRUNCATE": TokenType.TRUNCATE, - "TRY_CAST": TokenType.TRY_CAST, "UNBOUNDED": TokenType.UNBOUNDED, "UNION": TokenType.UNION, "UNNEST": TokenType.UNNEST, @@ -626,14 +642,12 @@ class Tokenizer(metaclass=_Tokenizer): break white_space = self.WHITE_SPACE.get(self._char) - identifier_end = self.IDENTIFIERS.get(self._char) + identifier_end = self._IDENTIFIERS.get(self._char) if white_space: if white_space == TokenType.BREAK: self._col = 1 self._line += 1 - elif self._char == "0" and self._peek == "x": - self._scan_hex() elif self._char.isdigit(): self._scan_number() elif identifier_end: @@ -666,9 +680,7 @@ class Tokenizer(metaclass=_Tokenizer): text = self._text if text is None else text self.tokens.append(Token(token_type, text, self._line, self._col)) - if token_type in self.COMMANDS and ( - len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON - ): + if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON): self._start = self._current while not self._end and self._peek != ";": self._advance() @@ -725,6 +737,8 @@ class Tokenizer(metaclass=_Tokenizer): if self._scan_string(word): return + if self._scan_numeric_string(word): + return if self._scan_comment(word): return @@ -732,10 +746,10 @@ class Tokenizer(metaclass=_Tokenizer): self._add(self.KEYWORDS[word.upper()]) def _scan_comment(self, comment_start): - if comment_start not in self.COMMENTS: + if comment_start not in self._COMMENTS: return False - comment_end = self.COMMENTS[comment_start] + comment_end = self._COMMENTS[comment_start] if comment_end: comment_end_size = len(comment_end) @@ -749,15 +763,18 @@ class Tokenizer(metaclass=_Tokenizer): return True def _scan_annotation(self): - while ( - not self._end - and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK - and self._peek != "," - ): + while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK and self._peek != ",": self._advance() self._add(TokenType.ANNOTATION, self._text[1:]) def _scan_number(self): + if self._char == "0": + peek = self._peek.upper() + if peek == "B": + return self._scan_bits() + elif peek == "X": + return self._scan_hex() + decimal = False scientific = 0 @@ -788,57 +805,71 @@ class Tokenizer(metaclass=_Tokenizer): else: return self._add(TokenType.NUMBER) + def _scan_bits(self): + self._advance() + value = self._extract_value() + try: + self._add(TokenType.BIT_STRING, f"{int(value, 2)}") + except ValueError: + self._add(TokenType.IDENTIFIER) + def _scan_hex(self): self._advance() + value = self._extract_value() + try: + self._add(TokenType.HEX_STRING, f"{int(value, 16)}") + except ValueError: + self._add(TokenType.IDENTIFIER) + def _extract_value(self): while True: char = self._peek.strip() if char and char not in self.SINGLE_TOKENS: self._advance() else: break - try: - self._add(TokenType.BIT_STRING, f"{int(self._text, 16):b}") - except ValueError: - self._add(TokenType.IDENTIFIER) + + return self._text def _scan_string(self, quote): - quote_end = self.QUOTES.get(quote) + quote_end = self._QUOTES.get(quote) if quote_end is None: return False - text = "" self._advance(len(quote)) - quote_end_size = 
len(quote_end) - - while True: - if self._char == self.ESCAPE and self._peek == quote_end: - text += quote - self._advance(2) - else: - if self._chars(quote_end_size) == quote_end: - if quote_end_size > 1: - self._advance(quote_end_size - 1) - break - - if self._end: - raise RuntimeError( - f"Missing {quote} from {self._line}:{self._start}" - ) - text += self._char - self._advance() + text = self._extract_string(quote_end) text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text self._add(TokenType.STRING, text) return True + def _scan_numeric_string(self, string_start): + if string_start in self._HEX_STRINGS: + delimiters = self._HEX_STRINGS + token_type = TokenType.HEX_STRING + base = 16 + elif string_start in self._BIT_STRINGS: + delimiters = self._BIT_STRINGS + token_type = TokenType.BIT_STRING + base = 2 + else: + return False + + self._advance(len(string_start)) + string_end = delimiters.get(string_start) + text = self._extract_string(string_end) + + try: + self._add(token_type, f"{int(text, base)}") + except ValueError: + raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}") + return True + def _scan_identifier(self, identifier_end): while self._peek != identifier_end: if self._end: - raise RuntimeError( - f"Missing {identifier_end} from {self._line}:{self._start}" - ) + raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}") self._advance() self._advance() self._add(TokenType.IDENTIFIER, self._text[1:-1]) @@ -851,3 +882,24 @@ class Tokenizer(metaclass=_Tokenizer): else: break self._add(self.KEYWORDS.get(self._text.upper(), TokenType.VAR)) + + def _extract_string(self, delimiter): + text = "" + delim_size = len(delimiter) + + while True: + if self._char == self.ESCAPE and self._peek == delimiter: + text += delimiter + self._advance(2) + else: + if self._chars(delim_size) == delimiter: + if delim_size > 1: + self._advance(delim_size - 1) + break + + if self._end: + raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}") + text += self._char + self._advance() + + return text diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index e7ccb8e..7fc71dd 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -12,9 +12,7 @@ def unalias_group(expression): """ if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): aliased_selects = { - e.alias: i - for i, e in enumerate(expression.parent.expressions, start=1) - if isinstance(e, exp.Alias) + e.alias: i for i, e in enumerate(expression.parent.expressions, start=1) if isinstance(e, exp.Alias) } expression = expression.copy() diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 3993565..6b7bfd3 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -36,9 +36,7 @@ class Validator(unittest.TestCase): for read_dialect, read_sql in (read or {}).items(): with self.subTest(f"{read_dialect} -> {sql}"): self.assertEqual( - parse_one(read_sql, read_dialect).sql( - self.dialect, unsupported_level=ErrorLevel.IGNORE - ), + parse_one(read_sql, read_dialect).sql(self.dialect, unsupported_level=ErrorLevel.IGNORE), sql, ) @@ -46,9 +44,7 @@ class Validator(unittest.TestCase): with self.subTest(f"{sql} -> {write_dialect}"): if write_sql is UnsupportedError: with self.assertRaises(UnsupportedError): - expression.sql( - write_dialect, unsupported_level=ErrorLevel.RAISE - ) + 
expression.sql(write_dialect, unsupported_level=ErrorLevel.RAISE) else: self.assertEqual( expression.sql( @@ -82,11 +78,19 @@ class TestDialect(Validator): "oracle": "CAST(a AS CLOB)", "postgres": "CAST(a AS TEXT)", "presto": "CAST(a AS VARCHAR)", + "redshift": "CAST(a AS TEXT)", "snowflake": "CAST(a AS TEXT)", "spark": "CAST(a AS STRING)", "starrocks": "CAST(a AS STRING)", }, ) + self.validate_all( + "CAST(a AS DATETIME)", + write={ + "postgres": "CAST(a AS TIMESTAMP)", + "sqlite": "CAST(a AS DATETIME)", + }, + ) self.validate_all( "CAST(a AS STRING)", write={ @@ -97,6 +101,7 @@ class TestDialect(Validator): "oracle": "CAST(a AS CLOB)", "postgres": "CAST(a AS TEXT)", "presto": "CAST(a AS VARCHAR)", + "redshift": "CAST(a AS TEXT)", "snowflake": "CAST(a AS TEXT)", "spark": "CAST(a AS STRING)", "starrocks": "CAST(a AS STRING)", @@ -112,6 +117,7 @@ class TestDialect(Validator): "oracle": "CAST(a AS VARCHAR2)", "postgres": "CAST(a AS VARCHAR)", "presto": "CAST(a AS VARCHAR)", + "redshift": "CAST(a AS VARCHAR)", "snowflake": "CAST(a AS VARCHAR)", "spark": "CAST(a AS STRING)", "starrocks": "CAST(a AS VARCHAR)", @@ -127,6 +133,7 @@ class TestDialect(Validator): "oracle": "CAST(a AS VARCHAR2(3))", "postgres": "CAST(a AS VARCHAR(3))", "presto": "CAST(a AS VARCHAR(3))", + "redshift": "CAST(a AS VARCHAR(3))", "snowflake": "CAST(a AS VARCHAR(3))", "spark": "CAST(a AS VARCHAR(3))", "starrocks": "CAST(a AS VARCHAR(3))", @@ -142,12 +149,26 @@ class TestDialect(Validator): "oracle": "CAST(a AS NUMBER)", "postgres": "CAST(a AS SMALLINT)", "presto": "CAST(a AS SMALLINT)", + "redshift": "CAST(a AS SMALLINT)", "snowflake": "CAST(a AS SMALLINT)", "spark": "CAST(a AS SHORT)", "sqlite": "CAST(a AS INTEGER)", "starrocks": "CAST(a AS SMALLINT)", }, ) + self.validate_all( + "TRY_CAST(a AS DOUBLE)", + read={ + "postgres": "CAST(a AS DOUBLE PRECISION)", + "redshift": "CAST(a AS DOUBLE PRECISION)", + }, + write={ + "duckdb": "TRY_CAST(a AS DOUBLE)", + "postgres": "CAST(a AS DOUBLE PRECISION)", + "redshift": "CAST(a AS DOUBLE PRECISION)", + }, + ) + self.validate_all( "CAST(a AS DOUBLE)", write={ @@ -159,16 +180,32 @@ class TestDialect(Validator): "oracle": "CAST(a AS DOUBLE PRECISION)", "postgres": "CAST(a AS DOUBLE PRECISION)", "presto": "CAST(a AS DOUBLE)", + "redshift": "CAST(a AS DOUBLE PRECISION)", "snowflake": "CAST(a AS DOUBLE)", "spark": "CAST(a AS DOUBLE)", "starrocks": "CAST(a AS DOUBLE)", }, ) self.validate_all( - "CAST(a AS TIMESTAMP)", write={"starrocks": "CAST(a AS DATETIME)"} + "CAST('1 DAY' AS INTERVAL)", + write={ + "postgres": "CAST('1 DAY' AS INTERVAL)", + "redshift": "CAST('1 DAY' AS INTERVAL)", + }, ) self.validate_all( - "CAST(a AS TIMESTAMPTZ)", write={"starrocks": "CAST(a AS DATETIME)"} + "CAST(a AS TIMESTAMP)", + write={ + "starrocks": "CAST(a AS DATETIME)", + "redshift": "CAST(a AS TIMESTAMP)", + }, + ) + self.validate_all( + "CAST(a AS TIMESTAMPTZ)", + write={ + "starrocks": "CAST(a AS DATETIME)", + "redshift": "CAST(a AS TIMESTAMPTZ)", + }, ) self.validate_all("CAST(a AS TINYINT)", write={"oracle": "CAST(a AS NUMBER)"}) self.validate_all("CAST(a AS SMALLINT)", write={"oracle": "CAST(a AS NUMBER)"}) @@ -552,6 +589,7 @@ class TestDialect(Validator): write={ "bigquery": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "oracle": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", "presto": "SELECT 
fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", @@ -566,6 +604,7 @@ class TestDialect(Validator): "presto": "JSON_EXTRACT(x, 'y')", }, write={ + "oracle": "JSON_EXTRACT(x, 'y')", "postgres": "x->'y'", "presto": "JSON_EXTRACT(x, 'y')", }, @@ -623,6 +662,37 @@ class TestDialect(Validator): }, ) + # https://dev.mysql.com/doc/refman/8.0/en/join.html + # https://www.postgresql.org/docs/current/queries-table-expressions.html + def test_joined_tables(self): + self.validate_identity("SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)") + self.validate_identity("SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)") + self.validate_identity("SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)") + self.validate_identity("SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)") + + self.validate_all( + "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)", + write={ + "postgres": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)", + "mysql": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)", + }, + ) + self.validate_all( + "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)", + write={ + "postgres": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)", + "mysql": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)", + }, + ) + + def test_lateral_subquery(self): + self.validate_identity( + "SELECT art FROM tbl1 INNER JOIN LATERAL (SELECT art FROM tbl2) AS tbl2 ON tbl1.art = tbl2.art" + ) + self.validate_identity( + "SELECT * FROM tbl AS t LEFT JOIN LATERAL (SELECT * FROM b WHERE b.t_id = t.t_id) AS t ON TRUE" + ) + def test_set_operators(self): self.validate_all( "SELECT * FROM a UNION SELECT * FROM b", @@ -731,6 +801,9 @@ class TestDialect(Validator): ) def test_operators(self): + self.validate_identity("some.column LIKE 'foo' || another.column || 'bar' || LOWER(x)") + self.validate_identity("some.column LIKE 'foo' + another.column + 'bar'") + self.validate_all( "x ILIKE '%y'", read={ @@ -874,16 +947,8 @@ class TestDialect(Validator): "spark": "FILTER(the_array, x -> x > 0)", }, ) - self.validate_all( - "SELECT a AS b FROM x GROUP BY b", - write={ - "duckdb": "SELECT a AS b FROM x GROUP BY b", - "presto": "SELECT a AS b FROM x GROUP BY 1", - "hive": "SELECT a AS b FROM x GROUP BY 1", - "oracle": "SELECT a AS b FROM x GROUP BY 1", - "spark": "SELECT a AS b FROM x GROUP BY 1", - }, - ) + + def test_limit(self): self.validate_all( "SELECT x FROM y LIMIT 10", write={ @@ -915,6 +980,7 @@ class TestDialect(Validator): read={ "clickhouse": '`x` + "y"', "sqlite": '`x` + "y"', + "redshift": '"x" + "y"', }, ) self.validate_all( @@ -977,5 +1043,36 @@ class TestDialect(Validator): "oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))", "postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))", "sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))", + "redshift": "CREATE TABLE t (b1 VARBYTE, b2 VARBYTE(1024), c1 TEXT, c2 TEXT(1024))", + }, + ) + + def test_alias(self): + self.validate_all( + "SELECT a AS b FROM x GROUP BY b", + write={ + "duckdb": "SELECT a AS b FROM x GROUP BY b", + "presto": "SELECT a AS b FROM x GROUP BY 1", + "hive": "SELECT a AS b FROM x GROUP BY 1", + "oracle": "SELECT a AS b FROM x GROUP BY 1", + "spark": "SELECT a AS b FROM x GROUP BY 1", + }, + ) + self.validate_all( + 
"SELECT y x FROM my_table t", + write={ + "hive": "SELECT y AS x FROM my_table AS t", + "oracle": "SELECT y AS x FROM my_table t", + "postgres": "SELECT y AS x FROM my_table AS t", + "sqlite": "SELECT y AS x FROM my_table AS t", + }, + ) + self.validate_all( + "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c", + write={ + "hive": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c", + "oracle": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 t JOIN cte2 WHERE cte1.a = cte2.c", + "postgres": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c", + "sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c", }, ) diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index eccd75a..55086e3 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -341,6 +341,21 @@ class TestHive(Validator): "spark": "PERCENTILE(x, 0.5)", }, ) + self.validate_all( + "PERCENTILE_APPROX(x, 0.5)", + read={ + "hive": "PERCENTILE_APPROX(x, 0.5)", + "presto": "APPROX_PERCENTILE(x, 0.5)", + "duckdb": "APPROX_QUANTILE(x, 0.5)", + "spark": "PERCENTILE_APPROX(x, 0.5)", + }, + write={ + "hive": "PERCENTILE_APPROX(x, 0.5)", + "presto": "APPROX_PERCENTILE(x, 0.5)", + "duckdb": "APPROX_QUANTILE(x, 0.5)", + "spark": "PERCENTILE_APPROX(x, 0.5)", + }, + ) self.validate_all( "APPROX_COUNT_DISTINCT(a)", write={ diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index ee0c5f5..87a3d64 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -15,6 +15,10 @@ class TestMySQL(Validator): def test_identity(self): self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo") + self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ')") + self.validate_identity("SELECT TRIM(TRAILING 'bla' FROM ' XXX ')") + self.validate_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')") + self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')") def test_introducers(self): self.validate_all( @@ -27,12 +31,22 @@ class TestMySQL(Validator): }, ) - def test_binary_literal(self): + def test_hexadecimal_literal(self): self.validate_all( "SELECT 0xCC", write={ - "mysql": "SELECT b'11001100'", - "spark": "SELECT X'11001100'", + "mysql": "SELECT x'CC'", + "sqlite": "SELECT x'CC'", + "spark": "SELECT X'CC'", + "trino": "SELECT X'CC'", + "bigquery": "SELECT 0xCC", + "oracle": "SELECT 204", + }, + ) + self.validate_all( + "SELECT X'1A'", + write={ + "mysql": "SELECT x'1A'", }, ) self.validate_all( @@ -41,10 +55,22 @@ class TestMySQL(Validator): "mysql": "SELECT `0xz`", }, ) + + def test_bits_literal(self): + self.validate_all( + "SELECT 0b1011", + write={ + "mysql": "SELECT b'1011'", + "postgres": "SELECT b'1011'", + "oracle": "SELECT 11", + }, + ) self.validate_all( - "SELECT 0XCC", + "SELECT B'1011'", write={ - "mysql": "SELECT 0 AS XCC", + "mysql": "SELECT b'1011'", + "postgres": "SELECT b'1011'", + "oracle": "SELECT 11", }, ) @@ -77,3 +103,19 @@ class TestMySQL(Validator): "mysql": "SELECT 1", }, ) + + def test_mysql(self): + self.validate_all( + "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)", + write={ + "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y 
DESC SEPARATOR ',')", + "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)", + }, + ) + self.validate_all( + "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')", + write={ + "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')", + "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC, '')", + }, + ) diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 15dbfd0..e0934d7 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -8,9 +8,7 @@ class TestPostgres(Validator): def test_ddl(self): self.validate_all( "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)", - write={ - "postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)" - }, + write={"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"}, ) self.validate_all( "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)", @@ -42,11 +40,17 @@ class TestPostgres(Validator): " CONSTRAINT valid_discount CHECK (price > discounted_price))" }, ) + self.validate_all( + "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)", + write={"postgres": "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)"}, + ) + self.validate_all( + "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)", + write={"postgres": "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)"}, + ) with self.assertRaises(ParseError): - transpile( - "CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres" - ) + transpile("CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres") with self.assertRaises(ParseError): transpile( "CREATE TABLE products (price DECIMAL, CHECK price > 1)", @@ -54,11 +58,16 @@ class TestPostgres(Validator): ) def test_postgres(self): - self.validate_all( - "CREATE TABLE x (a INT SERIAL)", - read={"sqlite": "CREATE TABLE x (a INTEGER AUTOINCREMENT)"}, - write={"sqlite": "CREATE TABLE x (a INTEGER AUTOINCREMENT)"}, - ) + self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END") + self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END") + self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END") + self.validate_identity('SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')') + self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')") + self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')") + self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))") + self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')") + self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)") + self.validate_all( "CREATE TABLE x (a UUID, b BYTEA)", write={ @@ -91,3 +100,65 @@ class TestPostgres(Validator): "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", }, ) + self.validate_all( + "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END", + write={ + "hive": "SELECT CASE WHEN SUBSTRING('abcdefg', 1, 2) IN ('ab') THEN 1 ELSE 0 END", + "spark": "SELECT CASE WHEN SUBSTRING('abcdefg', 1, 2) IN ('ab') THEN 1 ELSE 0 END", + }, + ) + self.validate_all( + "SELECT * FROM x WHERE SUBSTRING(col1 FROM 3 + LENGTH(col1) - 10 FOR 
10) IN (col2)", + write={ + "hive": "SELECT * FROM x WHERE SUBSTRING(col1, 3 + LENGTH(col1) - 10, 10) IN (col2)", + "spark": "SELECT * FROM x WHERE SUBSTRING(col1, 3 + LENGTH(col1) - 10, 10) IN (col2)", + }, + ) + self.validate_all( + "SELECT SUBSTRING(CAST(2022 AS CHAR(4)) || LPAD(CAST(3 AS CHAR(2)), 2, '0') FROM 3 FOR 4)", + read={ + "postgres": "SELECT SUBSTRING(2022::CHAR(4) || LPAD(3::CHAR(2), 2, '0') FROM 3 FOR 4)", + }, + ) + self.validate_all( + "SELECT TRIM(BOTH ' XXX ')", + write={ + "mysql": "SELECT TRIM(' XXX ')", + "postgres": "SELECT TRIM(' XXX ')", + "hive": "SELECT TRIM(' XXX ')", + }, + ) + self.validate_all( + "TRIM(LEADING FROM ' XXX ')", + write={ + "mysql": "LTRIM(' XXX ')", + "postgres": "LTRIM(' XXX ')", + "hive": "LTRIM(' XXX ')", + "presto": "LTRIM(' XXX ')", + }, + ) + self.validate_all( + "TRIM(TRAILING FROM ' XXX ')", + write={ + "mysql": "RTRIM(' XXX ')", + "postgres": "RTRIM(' XXX ')", + "hive": "RTRIM(' XXX ')", + "presto": "RTRIM(' XXX ')", + }, + ) + self.validate_all( + "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss", + read={"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"}, + ) + self.validate_all( + "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL", + read={ + "postgres": "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL", + }, + ) + self.validate_all( + "SELECT p1.id, p2.id, v1, v2 FROM polygons AS p1, polygons AS p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id <> p2.id", + read={ + "postgres": "SELECT p1.id, p2.id, v1, v2 FROM polygons p1, polygons p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id != p2.id", + }, + ) diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py new file mode 100644 index 0000000..1ed2bb6 --- /dev/null +++ b/tests/dialects/test_redshift.py @@ -0,0 +1,64 @@ +from tests.dialects.test_dialect import Validator + + +class TestRedshift(Validator): + dialect = "redshift" + + def test_redshift(self): + self.validate_all( + 'create table "group" ("col" char(10))', + write={ + "redshift": 'CREATE TABLE "group" ("col" CHAR(10))', + "mysql": "CREATE TABLE `group` (`col` CHAR(10))", + }, + ) + self.validate_all( + 'create table if not exists city_slash_id("city/id" integer not null, state char(2) not null)', + write={ + "redshift": 'CREATE TABLE IF NOT EXISTS city_slash_id ("city/id" INTEGER NOT NULL, state CHAR(2) NOT NULL)', + "presto": 'CREATE TABLE IF NOT EXISTS city_slash_id ("city/id" INTEGER NOT NULL, state CHAR(2) NOT NULL)', + }, + ) + self.validate_all( + "SELECT ST_AsEWKT(ST_GeomFromEWKT('SRID=4326;POINT(10 20)')::geography)", + write={ + "redshift": "SELECT ST_ASEWKT(CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))", + "bigquery": "SELECT ST_ASEWKT(CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))", + }, + ) + self.validate_all( + "SELECT ST_AsEWKT(ST_GeogFromText('LINESTRING(110 40, 2 3, -10 80, -7 9)')::geometry)", + write={ + "redshift": "SELECT ST_ASEWKT(CAST(ST_GEOGFROMTEXT('LINESTRING(110 40, 2 3, -10 80, -7 9)') AS GEOMETRY))", + }, + ) + self.validate_all( + "SELECT 'abc'::BINARY", + write={ + "redshift": "SELECT CAST('abc' AS VARBYTE)", + }, + ) + self.validate_all( + "SELECT * FROM venue WHERE (venuecity, venuestate) IN (('Miami', 'FL'), ('Tampa', 'FL')) ORDER BY 
venueid", + write={ + "redshift": "SELECT * FROM venue WHERE (venuecity, venuestate) IN (('Miami', 'FL'), ('Tampa', 'FL')) ORDER BY venueid", + }, + ) + self.validate_all( + 'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\_%\' LIMIT 5', + write={ + "redshift": 'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\\\_%\' LIMIT 5' + }, + ) + + def test_identity(self): + self.validate_identity("CAST('bla' AS SUPER)") + self.validate_identity("CREATE TABLE real1 (realcol REAL)") + self.validate_identity("CAST('foo' AS HLLSKETCH)") + self.validate_identity("SELECT DATEADD(day, 1, 'today')") + self.validate_identity("'abc' SIMILAR TO '(b|c)%'") + self.validate_identity( + "SELECT caldate + INTERVAL '1 second' AS dateplus FROM date WHERE caldate = '12-31-2008'" + ) + self.validate_identity("CREATE TABLE datetable (start_date DATE, end_date DATE)") + self.validate_identity("SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'") diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 62f78e1..2eeff52 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -143,3 +143,35 @@ class TestSnowflake(Validator): "snowflake": r"SELECT 'a \' \\ \\t \\x21 z $ '", }, ) + + def test_null_treatment(self): + self.validate_all( + r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1", + write={ + "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1" + }, + ) + self.validate_all( + r"SELECT FIRST_VALUE(TABLE1.COLUMN1 RESPECT NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1", + write={ + "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1" + }, + ) + self.validate_all( + r"SELECT FIRST_VALUE(TABLE1.COLUMN1) RESPECT NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1", + write={ + "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) RESPECT NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1" + }, + ) + self.validate_all( + r"SELECT FIRST_VALUE(TABLE1.COLUMN1 IGNORE NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1", + write={ + "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1 IGNORE NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1" + }, + ) + self.validate_all( + r"SELECT FIRST_VALUE(TABLE1.COLUMN1) IGNORE NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1", + write={ + "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) IGNORE NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1" + }, + ) diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index a0576de..3cc974c 100644 --- 
a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -34,6 +34,7 @@ class TestSQLite(Validator): write={ "sqlite": "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)", "mysql": "CREATE TABLE z (a INT UNIQUE PRIMARY KEY AUTO_INCREMENT)", + "postgres": "CREATE TABLE z (a INT GENERATED BY DEFAULT AS IDENTITY NOT NULL UNIQUE PRIMARY KEY)", }, ) self.validate_all( @@ -70,3 +71,20 @@ class TestSQLite(Validator): "sqlite": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", }, ) + + def test_hexadecimal_literal(self): + self.validate_all( + "SELECT 0XCC", + write={ + "sqlite": "SELECT x'CC'", + "mysql": "SELECT x'CC'", + }, + ) + + def test_window_null_treatment(self): + self.validate_all( + "SELECT FIRST_VALUE(Name) OVER (PARTITION BY AlbumId ORDER BY Bytes DESC) AS LargestTrack FROM tracks", + write={ + "sqlite": "SELECT FIRST_VALUE(Name) OVER (PARTITION BY AlbumId ORDER BY Bytes DESC) AS LargestTrack FROM tracks" + }, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 40f11a2..1b4168c 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -318,6 +318,9 @@ SELECT 1 FROM a JOIN b ON a.foo = b.bar JOIN c ON a.foo = c.bar SELECT 1 FROM a LEFT JOIN b ON a.foo = b.bar JOIN c ON a.foo = c.bar SELECT 1 FROM a LEFT INNER JOIN b ON a.foo = b.bar SELECT 1 FROM a LEFT OUTER JOIN b ON a.foo = b.bar +SELECT 1 FROM a NATURAL JOIN b +SELECT 1 FROM a NATURAL LEFT JOIN b +SELECT 1 FROM a NATURAL LEFT OUTER JOIN b SELECT 1 FROM a OUTER JOIN b ON a.foo = b.bar SELECT 1 FROM a FULL JOIN b ON a.foo = b.bar SELECT 1 UNION ALL SELECT 2 @@ -329,6 +332,7 @@ SELECT 1 AS delete, 2 AS alter SELECT * FROM (x) SELECT * FROM ((x)) SELECT * FROM ((SELECT 1)) +SELECT * FROM (x LATERAL VIEW EXPLODE(y) JOIN foo) SELECT * FROM (SELECT 1) AS x SELECT * FROM (SELECT 1 UNION SELECT 2) AS x SELECT * FROM (SELECT 1 UNION ALL SELECT 2) AS x @@ -430,6 +434,7 @@ CREATE TEMPORARY VIEW x AS SELECT a FROM d CREATE TEMPORARY VIEW IF NOT EXISTS x AS SELECT a FROM d CREATE TEMPORARY VIEW x AS WITH y AS (SELECT 1) SELECT * FROM y CREATE TABLE z (a INT, b VARCHAR, c VARCHAR(100), d DECIMAL(5, 3)) +CREATE TABLE z (end INT) CREATE TABLE z (a ARRAY, b MAP, c DECIMAL(5, 3)) CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECIMAL(5, 3)) CREATE TABLE z (a INT(11) DEFAULT UUID()) @@ -466,6 +471,7 @@ CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1 CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS WITH a AS (SELECT 1) SELECT a.* FROM a CACHE LAZY TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a CACHE TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a +CACHE TABLE x AS (SELECT 1 AS y) CALL catalog.system.iceberg_procedure_name(named_arg_1 => 'arg_1', named_arg_2 => 'arg_2') INSERT OVERWRITE TABLE a.b PARTITION(ds) SELECT x FROM y INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD') SELECT x FROM y @@ -512,3 +518,4 @@ SELECT ? AS ? FROM x WHERE b BETWEEN ? AND ? GROUP BY ?, 1 LIMIT ? 
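Every statement added to identity.sql asserts a parse/generate round-trip that leaves the SQL unchanged. A minimal sketch of that contract through the public API, using one of the NATURAL JOIN fixture lines added above:

    import sqlglot

    sql = "SELECT 1 FROM a NATURAL LEFT OUTER JOIN b"
    assert sqlglot.transpile(sql)[0] == sql  # parse + regenerate must be a no-op
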
WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS foo, c.bar AS bar FROM c)) SELECT * FROM a WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z +SELECT ((SELECT 1) + 1) diff --git a/tests/fixtures/optimizer/merge_derived_tables.sql b/tests/fixtures/optimizer/merge_derived_tables.sql new file mode 100644 index 0000000..c5aa7e9 --- /dev/null +++ b/tests/fixtures/optimizer/merge_derived_tables.sql @@ -0,0 +1,63 @@ +-- Simple +SELECT a, b FROM (SELECT a, b FROM x); +SELECT x.a AS a, x.b AS b FROM x AS x; + +-- Inner table alias is merged +SELECT a, b FROM (SELECT a, b FROM x AS q) AS r; +SELECT q.a AS a, q.b AS b FROM x AS q; + +-- Double nesting +SELECT a, b FROM (SELECT a, b FROM (SELECT a, b FROM x)); +SELECT x.a AS a, x.b AS b FROM x AS x; + +-- WHERE clause is merged +SELECT a, SUM(b) FROM (SELECT a, b FROM x WHERE a > 1) GROUP BY a; +SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a; + +-- Outer query has join +SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b; +SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1; + +-- Join on derived table +SELECT a, c FROM x JOIN (SELECT b, c FROM y) AS y ON x.b = y.b; +SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; + +-- Inner query has a join +SELECT a, c FROM (SELECT a, c FROM x JOIN y ON x.b = y.b); +SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; + +-- Inner query has conflicting name in outer query +SELECT a, c FROM (SELECT q.a, q.b FROM x AS q) AS x JOIN y AS q ON x.b = q.b; +SELECT q_2.a AS a, q.c AS c FROM x AS q_2 JOIN y AS q ON q_2.b = q.b; + +-- Inner query has conflicting name in joined source +SELECT x.a, q.c FROM (SELECT a, x.b FROM x JOIN y AS q ON x.b = q.b) AS x JOIN y AS q ON x.b = q.b; +SELECT x.a AS a, q.c AS c FROM x AS x JOIN y AS q_2 ON x.b = q_2.b JOIN y AS q ON x.b = q.b; + +-- Inner query has multiple conflicting names +SELECT x.a, q.c, r.c FROM (SELECT q.a, r.b FROM x AS q JOIN y AS r ON q.b = r.b) AS x JOIN y AS q ON x.b = q.b JOIN y AS r ON x.b = r.b; +SELECT q_2.a AS a, q.c AS c, r.c AS c FROM x AS q_2 JOIN y AS r_2 ON q_2.b = r_2.b JOIN y AS q ON r_2.b = q.b JOIN y AS r ON r_2.b = r.b; + +-- Inner queries have conflicting names with each other +SELECT r.b FROM (SELECT b FROM x AS x) AS q JOIN (SELECT b FROM x) AS r ON q.b = r.b; +SELECT x_2.b AS b FROM x AS x JOIN x AS x_2 ON x.b = x_2.b; + +-- WHERE clause in joined derived table is merged +SELECT x.a, y.c FROM x JOIN (SELECT b, c FROM y WHERE c > 1) AS y; +SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y WHERE y.c > 1; + +-- WHERE clause in outer joined derived table is merged to ON clause +SELECT x.a, y.c FROM x LEFT JOIN (SELECT b, c FROM y WHERE c > 1) AS y; +SELECT x.a AS a, y.c AS c FROM x AS x LEFT JOIN y AS y ON y.c > 1; + +-- Comma JOIN in outer query +SELECT x.a, y.c FROM (SELECT a FROM x) AS x, (SELECT c FROM y) AS y; +SELECT x.a AS a, y.c AS c FROM x AS x, y AS y; + +-- Comma JOIN in inner query +SELECT x.a, x.c FROM (SELECT x.a, z.c FROM x, y AS z) AS x; +SELECT x.a AS a, z.c AS c FROM x AS x CROSS JOIN y AS z; + +-- (Regression) Column in ORDER BY +SELECT * FROM (SELECT * FROM (SELECT * FROM x)) ORDER BY a LIMIT 1; +SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a LIMIT 1; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index f7bbdda..f1d0f7d 100644 --- a/tests/fixtures/optimizer/optimizer.sql 
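The merge_derived_tables fixtures above pair each input query with the expected result of collapsing its derived tables into the outer scope. A sketch of driving that rule through the optimizer entry point; the schema dict is an assumption mirroring the columns the fixtures reference, and the printed shape is indicative only:

    from sqlglot import parse_one
    from sqlglot.optimizer import optimize

    # Assumed schema for the fixture tables x and y.
    schema = {"x": {"a": "INT", "b": "INT"}, "y": {"b": "INT", "c": "INT"}}

    expression = optimize(parse_one("SELECT a, b FROM (SELECT a, b FROM x)"), schema=schema)
    # Expected shape after merging: SELECT "x"."a" AS "a", "x"."b" AS "b" FROM "x" AS "x"
    print(expression.sql())
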
+++ b/tests/fixtures/optimizer/optimizer.sql @@ -2,11 +2,7 @@ SELECT a, m FROM z LATERAL VIEW EXPLODE([1, 2]) q AS m; SELECT "z"."a" AS "a", "q"."m" AS "m" -FROM ( - SELECT - "z"."a" AS "a" - FROM "z" AS "z" -) AS "z" +FROM "z" AS "z" LATERAL VIEW EXPLODE(ARRAY(1, 2)) q AS "m"; @@ -91,41 +87,26 @@ FROM ( WHERE (TRUE AND TRUE OR 'a' = 'b') AND a > 1 GROUP BY a; SELECT - "d"."a" AS "a", - SUM("d"."b") AS "_col_1" -FROM ( + "x"."a" AS "a", + SUM("y"."b") AS "_col_1" +FROM "x" AS "x" +LEFT JOIN ( SELECT - "x"."a" AS "a", - "y"."b" AS "b" - FROM ( - SELECT - "x"."a" AS "a" - FROM "x" AS "x" - WHERE - "x"."a" > 1 - ) AS "x" - LEFT JOIN ( - SELECT - MAX("y"."b") AS "_col_0", - "y"."a" AS "_u_1" - FROM "y" AS "y" - GROUP BY - "y"."a" - ) AS "_u_0" - ON "x"."a" = "_u_0"."_u_1" - JOIN ( - SELECT - "y"."a" AS "a", - "y"."b" AS "b" - FROM "y" AS "y" - ) AS "y" - ON "x"."a" = "y"."a" - WHERE - "_u_0"."_col_0" >= 0 - AND NOT "_u_0"."_u_1" IS NULL -) AS "d" + MAX("y"."b") AS "_col_0", + "y"."a" AS "_u_1" + FROM "y" AS "y" + GROUP BY + "y"."a" +) AS "_u_0" + ON "x"."a" = "_u_0"."_u_1" +JOIN "y" AS "y" + ON "x"."a" = "y"."a" +WHERE + "_u_0"."_col_0" >= 0 + AND "x"."a" > 1 + AND NOT "_u_0"."_u_1" IS NULL GROUP BY - "d"."a"; + "x"."a"; (SELECT a FROM x) LIMIT 1; ( diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql index 482e231..0b6d382 100644 --- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql +++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql @@ -120,36 +120,16 @@ SELECT "supplier"."s_address" AS "s_address", "supplier"."s_phone" AS "s_phone", "supplier"."s_comment" AS "s_comment" -FROM ( - SELECT - "part"."p_partkey" AS "p_partkey", - "part"."p_mfgr" AS "p_mfgr", - "part"."p_type" AS "p_type", - "part"."p_size" AS "p_size" - FROM "part" AS "part" - WHERE - "part"."p_size" = 15 - AND "part"."p_type" LIKE '%BRASS' -) AS "part" +FROM "part" AS "part" LEFT JOIN ( SELECT MIN("partsupp"."ps_supplycost") AS "_col_0", "partsupp"."ps_partkey" AS "_u_1" FROM "_e_0" AS "partsupp" CROSS JOIN "_e_1" AS "region" - JOIN ( - SELECT - "nation"."n_nationkey" AS "n_nationkey", - "nation"."n_regionkey" AS "n_regionkey" - FROM "nation" AS "nation" - ) AS "nation" + JOIN "nation" AS "nation" ON "nation"."n_regionkey" = "region"."r_regionkey" - JOIN ( - SELECT - "supplier"."s_suppkey" AS "s_suppkey", - "supplier"."s_nationkey" AS "s_nationkey" - FROM "supplier" AS "supplier" - ) AS "supplier" + JOIN "supplier" AS "supplier" ON "supplier"."s_nationkey" = "nation"."n_nationkey" AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey" GROUP BY @@ -157,31 +137,17 @@ LEFT JOIN ( ) AS "_u_0" ON "part"."p_partkey" = "_u_0"."_u_1" CROSS JOIN "_e_1" AS "region" -JOIN ( - SELECT - "nation"."n_nationkey" AS "n_nationkey", - "nation"."n_name" AS "n_name", - "nation"."n_regionkey" AS "n_regionkey" - FROM "nation" AS "nation" -) AS "nation" +JOIN "nation" AS "nation" ON "nation"."n_regionkey" = "region"."r_regionkey" JOIN "_e_0" AS "partsupp" ON "part"."p_partkey" = "partsupp"."ps_partkey" -JOIN ( - SELECT - "supplier"."s_suppkey" AS "s_suppkey", - "supplier"."s_name" AS "s_name", - "supplier"."s_address" AS "s_address", - "supplier"."s_nationkey" AS "s_nationkey", - "supplier"."s_phone" AS "s_phone", - "supplier"."s_acctbal" AS "s_acctbal", - "supplier"."s_comment" AS "s_comment" - FROM "supplier" AS "supplier" -) AS "supplier" +JOIN "supplier" AS "supplier" ON "supplier"."s_nationkey" = "nation"."n_nationkey" AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey" WHERE - "partsupp"."ps_supplycost" = 
"_u_0"."_col_0" + "part"."p_size" = 15 + AND "part"."p_type" LIKE '%BRASS' + AND "partsupp"."ps_supplycost" = "_u_0"."_col_0" AND NOT "_u_0"."_u_1" IS NULL ORDER BY "s_acctbal" DESC, @@ -224,36 +190,15 @@ SELECT )) AS "revenue", CAST("orders"."o_orderdate" AS TEXT) AS "o_orderdate", "orders"."o_shippriority" AS "o_shippriority" -FROM ( - SELECT - "customer"."c_custkey" AS "c_custkey", - "customer"."c_mktsegment" AS "c_mktsegment" - FROM "customer" AS "customer" - WHERE - "customer"."c_mktsegment" = 'BUILDING' -) AS "customer" -JOIN ( - SELECT - "orders"."o_orderkey" AS "o_orderkey", - "orders"."o_custkey" AS "o_custkey", - "orders"."o_orderdate" AS "o_orderdate", - "orders"."o_shippriority" AS "o_shippriority" - FROM "orders" AS "orders" - WHERE - "orders"."o_orderdate" < '1995-03-15' -) AS "orders" +FROM "customer" AS "customer" +JOIN "orders" AS "orders" ON "customer"."c_custkey" = "orders"."o_custkey" -JOIN ( - SELECT - "lineitem"."l_orderkey" AS "l_orderkey", - "lineitem"."l_extendedprice" AS "l_extendedprice", - "lineitem"."l_discount" AS "l_discount", - "lineitem"."l_shipdate" AS "l_shipdate" - FROM "lineitem" AS "lineitem" - WHERE - "lineitem"."l_shipdate" > '1995-03-15' -) AS "lineitem" +JOIN "lineitem" AS "lineitem" ON "lineitem"."l_orderkey" = "orders"."o_orderkey" +WHERE + "customer"."c_mktsegment" = 'BUILDING' + AND "lineitem"."l_shipdate" > '1995-03-15' + AND "orders"."o_orderdate" < '1995-03-15' GROUP BY "lineitem"."l_orderkey", "orders"."o_orderdate", @@ -342,57 +287,22 @@ SELECT SUM("lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" )) AS "revenue" -FROM ( - SELECT - "customer"."c_custkey" AS "c_custkey", - "customer"."c_nationkey" AS "c_nationkey" - FROM "customer" AS "customer" -) AS "customer" -JOIN ( - SELECT - "orders"."o_orderkey" AS "o_orderkey", - "orders"."o_custkey" AS "o_custkey", - "orders"."o_orderdate" AS "o_orderdate" - FROM "orders" AS "orders" - WHERE - "orders"."o_orderdate" < CAST('1995-01-01' AS DATE) - AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE) -) AS "orders" +FROM "customer" AS "customer" +JOIN "orders" AS "orders" ON "customer"."c_custkey" = "orders"."o_custkey" -CROSS JOIN ( - SELECT - "region"."r_regionkey" AS "r_regionkey", - "region"."r_name" AS "r_name" - FROM "region" AS "region" - WHERE - "region"."r_name" = 'ASIA' -) AS "region" -JOIN ( - SELECT - "nation"."n_nationkey" AS "n_nationkey", - "nation"."n_name" AS "n_name", - "nation"."n_regionkey" AS "n_regionkey" - FROM "nation" AS "nation" -) AS "nation" +CROSS JOIN "region" AS "region" +JOIN "nation" AS "nation" ON "nation"."n_regionkey" = "region"."r_regionkey" -JOIN ( - SELECT - "supplier"."s_suppkey" AS "s_suppkey", - "supplier"."s_nationkey" AS "s_nationkey" - FROM "supplier" AS "supplier" -) AS "supplier" +JOIN "supplier" AS "supplier" ON "customer"."c_nationkey" = "supplier"."s_nationkey" AND "supplier"."s_nationkey" = "nation"."n_nationkey" -JOIN ( - SELECT - "lineitem"."l_orderkey" AS "l_orderkey", - "lineitem"."l_suppkey" AS "l_suppkey", - "lineitem"."l_extendedprice" AS "l_extendedprice", - "lineitem"."l_discount" AS "l_discount" - FROM "lineitem" AS "lineitem" -) AS "lineitem" +JOIN "lineitem" AS "lineitem" ON "lineitem"."l_orderkey" = "orders"."o_orderkey" AND "lineitem"."l_suppkey" = "supplier"."s_suppkey" +WHERE + "orders"."o_orderdate" < CAST('1995-01-01' AS DATE) + AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE) + AND "region"."r_name" = 'ASIA' GROUP BY "nation"."n_name" ORDER BY @@ -471,67 +381,37 @@ WITH "_e_0" AS ( OR "nation"."n_name" = 
'GERMANY' ) SELECT - "shipping"."supp_nation" AS "supp_nation", - "shipping"."cust_nation" AS "cust_nation", - "shipping"."l_year" AS "l_year", - SUM("shipping"."volume") AS "revenue" -FROM ( - SELECT - "n1"."n_name" AS "supp_nation", - "n2"."n_name" AS "cust_nation", - EXTRACT(year FROM "lineitem"."l_shipdate") AS "l_year", - "lineitem"."l_extendedprice" * ( + "n1"."n_name" AS "supp_nation", + "n2"."n_name" AS "cust_nation", + EXTRACT(year FROM "lineitem"."l_shipdate") AS "l_year", + SUM("lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" - ) AS "volume" - FROM ( - SELECT - "supplier"."s_suppkey" AS "s_suppkey", - "supplier"."s_nationkey" AS "s_nationkey" - FROM "supplier" AS "supplier" - ) AS "supplier" - JOIN ( - SELECT - "lineitem"."l_orderkey" AS "l_orderkey", - "lineitem"."l_suppkey" AS "l_suppkey", - "lineitem"."l_extendedprice" AS "l_extendedprice", - "lineitem"."l_discount" AS "l_discount", - "lineitem"."l_shipdate" AS "l_shipdate" - FROM "lineitem" AS "lineitem" - WHERE - "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) - ) AS "lineitem" - ON "supplier"."s_suppkey" = "lineitem"."l_suppkey" - JOIN ( - SELECT - "orders"."o_orderkey" AS "o_orderkey", - "orders"."o_custkey" AS "o_custkey" - FROM "orders" AS "orders" - ) AS "orders" - ON "orders"."o_orderkey" = "lineitem"."l_orderkey" - JOIN ( - SELECT - "customer"."c_custkey" AS "c_custkey", - "customer"."c_nationkey" AS "c_nationkey" - FROM "customer" AS "customer" - ) AS "customer" - ON "customer"."c_custkey" = "orders"."o_custkey" - JOIN "_e_0" AS "n1" - ON "supplier"."s_nationkey" = "n1"."n_nationkey" - JOIN "_e_0" AS "n2" - ON "customer"."c_nationkey" = "n2"."n_nationkey" - AND ( - "n1"."n_name" = 'FRANCE' - OR "n2"."n_name" = 'FRANCE' - ) - AND ( - "n1"."n_name" = 'GERMANY' - OR "n2"."n_name" = 'GERMANY' - ) -) AS "shipping" + )) AS "revenue" +FROM "supplier" AS "supplier" +JOIN "lineitem" AS "lineitem" + ON "supplier"."s_suppkey" = "lineitem"."l_suppkey" +JOIN "orders" AS "orders" + ON "orders"."o_orderkey" = "lineitem"."l_orderkey" +JOIN "customer" AS "customer" + ON "customer"."c_custkey" = "orders"."o_custkey" +JOIN "_e_0" AS "n1" + ON "supplier"."s_nationkey" = "n1"."n_nationkey" +JOIN "_e_0" AS "n2" + ON "customer"."c_nationkey" = "n2"."n_nationkey" + AND ( + "n1"."n_name" = 'FRANCE' + OR "n2"."n_name" = 'FRANCE' + ) + AND ( + "n1"."n_name" = 'GERMANY' + OR "n2"."n_name" = 'GERMANY' + ) +WHERE + "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) GROUP BY - "shipping"."supp_nation", - "shipping"."cust_nation", - "shipping"."l_year" + "n1"."n_name", + "n2"."n_name", + EXTRACT(year FROM "lineitem"."l_shipdate") ORDER BY "supp_nation", "cust_nation", @@ -578,87 +458,37 @@ group by order by o_year; SELECT - "all_nations"."o_year" AS "o_year", + EXTRACT(year FROM "orders"."o_orderdate") AS "o_year", SUM(CASE - WHEN "all_nations"."nation" = 'BRAZIL' - THEN "all_nations"."volume" + WHEN "nation_2"."n_name" = 'BRAZIL' + THEN "lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + ) ELSE 0 - END) / SUM("all_nations"."volume") AS "mkt_share" -FROM ( - SELECT - EXTRACT(year FROM "orders"."o_orderdate") AS "o_year", - "lineitem"."l_extendedprice" * ( + END) / SUM("lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" - ) AS "volume", - "n2"."n_name" AS "nation" - FROM ( - SELECT - "part"."p_partkey" AS "p_partkey", - "part"."p_type" AS "p_type" - FROM "part" AS "part" - WHERE - "part"."p_type" = 'ECONOMY ANODIZED STEEL' - ) AS 
"part" - CROSS JOIN ( - SELECT - "region"."r_regionkey" AS "r_regionkey", - "region"."r_name" AS "r_name" - FROM "region" AS "region" - WHERE - "region"."r_name" = 'AMERICA' - ) AS "region" - JOIN ( - SELECT - "nation"."n_nationkey" AS "n_nationkey", - "nation"."n_regionkey" AS "n_regionkey" - FROM "nation" AS "nation" - ) AS "n1" - ON "n1"."n_regionkey" = "region"."r_regionkey" - JOIN ( - SELECT - "customer"."c_custkey" AS "c_custkey", - "customer"."c_nationkey" AS "c_nationkey" - FROM "customer" AS "customer" - ) AS "customer" - ON "customer"."c_nationkey" = "n1"."n_nationkey" - JOIN ( - SELECT - "orders"."o_orderkey" AS "o_orderkey", - "orders"."o_custkey" AS "o_custkey", - "orders"."o_orderdate" AS "o_orderdate" - FROM "orders" AS "orders" - WHERE - "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) - ) AS "orders" - ON "orders"."o_custkey" = "customer"."c_custkey" - JOIN ( - SELECT - "lineitem"."l_orderkey" AS "l_orderkey", - "lineitem"."l_partkey" AS "l_partkey", - "lineitem"."l_suppkey" AS "l_suppkey", - "lineitem"."l_extendedprice" AS "l_extendedprice", - "lineitem"."l_discount" AS "l_discount" - FROM "lineitem" AS "lineitem" - ) AS "lineitem" - ON "lineitem"."l_orderkey" = "orders"."o_orderkey" - AND "part"."p_partkey" = "lineitem"."l_partkey" - JOIN ( - SELECT - "supplier"."s_suppkey" AS "s_suppkey", - "supplier"."s_nationkey" AS "s_nationkey" - FROM "supplier" AS "supplier" - ) AS "supplier" - ON "supplier"."s_suppkey" = "lineitem"."l_suppkey" - JOIN ( - SELECT - "nation"."n_nationkey" AS "n_nationkey", - "nation"."n_name" AS "n_name" - FROM "nation" AS "nation" - ) AS "n2" - ON "supplier"."s_nationkey" = "n2"."n_nationkey" -) AS "all_nations" + )) AS "mkt_share" +FROM "part" AS "part" +CROSS JOIN "region" AS "region" +JOIN "nation" AS "nation" + ON "nation"."n_regionkey" = "region"."r_regionkey" +JOIN "customer" AS "customer" + ON "customer"."c_nationkey" = "nation"."n_nationkey" +JOIN "orders" AS "orders" + ON "orders"."o_custkey" = "customer"."c_custkey" +JOIN "lineitem" AS "lineitem" + ON "lineitem"."l_orderkey" = "orders"."o_orderkey" + AND "part"."p_partkey" = "lineitem"."l_partkey" +JOIN "supplier" AS "supplier" + ON "supplier"."s_suppkey" = "lineitem"."l_suppkey" +JOIN "nation" AS "nation_2" + ON "supplier"."s_nationkey" = "nation_2"."n_nationkey" +WHERE + "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) + AND "part"."p_type" = 'ECONOMY ANODIZED STEEL' + AND "region"."r_name" = 'AMERICA' GROUP BY - "all_nations"."o_year" + EXTRACT(year FROM "orders"."o_orderdate") ORDER BY "o_year"; @@ -698,69 +528,28 @@ order by nation, o_year desc; SELECT - "profit"."nation" AS "nation", - "profit"."o_year" AS "o_year", - SUM("profit"."amount") AS "sum_profit" -FROM ( - SELECT - "nation"."n_name" AS "nation", - EXTRACT(year FROM "orders"."o_orderdate") AS "o_year", - "lineitem"."l_extendedprice" * ( + "nation"."n_name" AS "nation", + EXTRACT(year FROM "orders"."o_orderdate") AS "o_year", + SUM("lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" - ) - "partsupp"."ps_supplycost" * "lineitem"."l_quantity" AS "amount" - FROM ( - SELECT - "part"."p_partkey" AS "p_partkey", - "part"."p_name" AS "p_name" - FROM "part" AS "part" - WHERE - "part"."p_name" LIKE '%green%' - ) AS "part" - JOIN ( - SELECT - "lineitem"."l_orderkey" AS "l_orderkey", - "lineitem"."l_partkey" AS "l_partkey", - "lineitem"."l_suppkey" AS "l_suppkey", - "lineitem"."l_quantity" AS "l_quantity", - "lineitem"."l_extendedprice" AS 
"l_extendedprice", - "lineitem"."l_discount" AS "l_discount" - FROM "lineitem" AS "lineitem" - ) AS "lineitem" - ON "part"."p_partkey" = "lineitem"."l_partkey" - JOIN ( - SELECT - "supplier"."s_suppkey" AS "s_suppkey", - "supplier"."s_nationkey" AS "s_nationkey" - FROM "supplier" AS "supplier" - ) AS "supplier" - ON "supplier"."s_suppkey" = "lineitem"."l_suppkey" - JOIN ( - SELECT - "partsupp"."ps_partkey" AS "ps_partkey", - "partsupp"."ps_suppkey" AS "ps_suppkey", - "partsupp"."ps_supplycost" AS "ps_supplycost" - FROM "partsupp" AS "partsupp" - ) AS "partsupp" - ON "partsupp"."ps_partkey" = "lineitem"."l_partkey" - AND "partsupp"."ps_suppkey" = "lineitem"."l_suppkey" - JOIN ( - SELECT - "orders"."o_orderkey" AS "o_orderkey", - "orders"."o_orderdate" AS "o_orderdate" - FROM "orders" AS "orders" - ) AS "orders" - ON "orders"."o_orderkey" = "lineitem"."l_orderkey" - JOIN ( - SELECT - "nation"."n_nationkey" AS "n_nationkey", - "nation"."n_name" AS "n_name" - FROM "nation" AS "nation" - ) AS "nation" - ON "supplier"."s_nationkey" = "nation"."n_nationkey" -) AS "profit" + ) - "partsupp"."ps_supplycost" * "lineitem"."l_quantity") AS "sum_profit" +FROM "part" AS "part" +JOIN "lineitem" AS "lineitem" + ON "part"."p_partkey" = "lineitem"."l_partkey" +JOIN "supplier" AS "supplier" + ON "supplier"."s_suppkey" = "lineitem"."l_suppkey" +JOIN "partsupp" AS "partsupp" + ON "partsupp"."ps_partkey" = "lineitem"."l_partkey" + AND "partsupp"."ps_suppkey" = "lineitem"."l_suppkey" +JOIN "orders" AS "orders" + ON "orders"."o_orderkey" = "lineitem"."l_orderkey" +JOIN "nation" AS "nation" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" +WHERE + "part"."p_name" LIKE '%green%' GROUP BY - "profit"."nation", - "profit"."o_year" + "nation"."n_name", + EXTRACT(year FROM "orders"."o_orderdate") ORDER BY "nation", "o_year" DESC; @@ -812,46 +601,17 @@ SELECT "customer"."c_address" AS "c_address", "customer"."c_phone" AS "c_phone", "customer"."c_comment" AS "c_comment" -FROM ( - SELECT - "customer"."c_custkey" AS "c_custkey", - "customer"."c_name" AS "c_name", - "customer"."c_address" AS "c_address", - "customer"."c_nationkey" AS "c_nationkey", - "customer"."c_phone" AS "c_phone", - "customer"."c_acctbal" AS "c_acctbal", - "customer"."c_comment" AS "c_comment" - FROM "customer" AS "customer" -) AS "customer" -JOIN ( - SELECT - "orders"."o_orderkey" AS "o_orderkey", - "orders"."o_custkey" AS "o_custkey", - "orders"."o_orderdate" AS "o_orderdate" - FROM "orders" AS "orders" - WHERE - "orders"."o_orderdate" < CAST('1994-01-01' AS DATE) - AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE) -) AS "orders" +FROM "customer" AS "customer" +JOIN "orders" AS "orders" ON "customer"."c_custkey" = "orders"."o_custkey" -JOIN ( - SELECT - "lineitem"."l_orderkey" AS "l_orderkey", - "lineitem"."l_extendedprice" AS "l_extendedprice", - "lineitem"."l_discount" AS "l_discount", - "lineitem"."l_returnflag" AS "l_returnflag" - FROM "lineitem" AS "lineitem" - WHERE - "lineitem"."l_returnflag" = 'R' -) AS "lineitem" +JOIN "lineitem" AS "lineitem" ON "lineitem"."l_orderkey" = "orders"."o_orderkey" -JOIN ( - SELECT - "nation"."n_nationkey" AS "n_nationkey", - "nation"."n_name" AS "n_name" - FROM "nation" AS "nation" -) AS "nation" +JOIN "nation" AS "nation" ON "customer"."c_nationkey" = "nation"."n_nationkey" +WHERE + "lineitem"."l_returnflag" = 'R' + AND "orders"."o_orderdate" < CAST('1994-01-01' AS DATE) + AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE) GROUP BY "customer"."c_custkey", "customer"."c_name", @@ -910,14 +670,7 
@@ WITH "_e_0" AS ( SELECT "partsupp"."ps_partkey" AS "ps_partkey", SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value" -FROM ( - SELECT - "partsupp"."ps_partkey" AS "ps_partkey", - "partsupp"."ps_suppkey" AS "ps_suppkey", - "partsupp"."ps_availqty" AS "ps_availqty", - "partsupp"."ps_supplycost" AS "ps_supplycost" - FROM "partsupp" AS "partsupp" -) AS "partsupp" +FROM "partsupp" AS "partsupp" JOIN "_e_0" AS "supplier" ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" JOIN "_e_1" AS "nation" @@ -928,13 +681,7 @@ HAVING SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > ( SELECT SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0" - FROM ( - SELECT - "partsupp"."ps_suppkey" AS "ps_suppkey", - "partsupp"."ps_availqty" AS "ps_availqty", - "partsupp"."ps_supplycost" AS "ps_supplycost" - FROM "partsupp" AS "partsupp" - ) AS "partsupp" + FROM "partsupp" AS "partsupp" JOIN "_e_0" AS "supplier" ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" JOIN "_e_1" AS "nation" @@ -988,28 +735,15 @@ SELECT THEN 1 ELSE 0 END) AS "low_line_count" -FROM ( - SELECT - "orders"."o_orderkey" AS "o_orderkey", - "orders"."o_orderpriority" AS "o_orderpriority" - FROM "orders" AS "orders" -) AS "orders" -JOIN ( - SELECT - "lineitem"."l_orderkey" AS "l_orderkey", - "lineitem"."l_shipdate" AS "l_shipdate", - "lineitem"."l_commitdate" AS "l_commitdate", - "lineitem"."l_receiptdate" AS "l_receiptdate", - "lineitem"."l_shipmode" AS "l_shipmode" - FROM "lineitem" AS "lineitem" - WHERE - "lineitem"."l_commitdate" < "lineitem"."l_receiptdate" - AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE) - AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE) - AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate" - AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP') -) AS "lineitem" +FROM "orders" AS "orders" +JOIN "lineitem" AS "lineitem" ON "orders"."o_orderkey" = "lineitem"."l_orderkey" +WHERE + "lineitem"."l_commitdate" < "lineitem"."l_receiptdate" + AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE) + AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE) + AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate" + AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP') GROUP BY "lineitem"."l_shipmode" ORDER BY @@ -1044,21 +778,10 @@ SELECT FROM ( SELECT COUNT("orders"."o_orderkey") AS "c_count" - FROM ( - SELECT - "customer"."c_custkey" AS "c_custkey" - FROM "customer" AS "customer" - ) AS "customer" - LEFT JOIN ( - SELECT - "orders"."o_orderkey" AS "o_orderkey", - "orders"."o_custkey" AS "o_custkey", - "orders"."o_comment" AS "o_comment" - FROM "orders" AS "orders" - WHERE - NOT "orders"."o_comment" LIKE '%special%requests%' - ) AS "orders" + FROM "customer" AS "customer" + LEFT JOIN "orders" AS "orders" ON "customer"."c_custkey" = "orders"."o_custkey" + AND NOT "orders"."o_comment" LIKE '%special%requests%' GROUP BY "customer"."c_custkey" ) AS "c_orders" @@ -1094,24 +817,12 @@ SELECT END) / SUM("lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" )) AS "promo_revenue" -FROM ( - SELECT - "lineitem"."l_partkey" AS "l_partkey", - "lineitem"."l_extendedprice" AS "l_extendedprice", - "lineitem"."l_discount" AS "l_discount", - "lineitem"."l_shipdate" AS "l_shipdate" - FROM "lineitem" AS "lineitem" - WHERE - "lineitem"."l_shipdate" < CAST('1995-10-01' AS DATE) - AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE) -) AS "lineitem" -JOIN ( - SELECT - "part"."p_partkey" AS "p_partkey", - "part"."p_type" AS "p_type" - FROM "part" AS "part" -) AS 
"part" - ON "lineitem"."l_partkey" = "part"."p_partkey"; +FROM "lineitem" AS "lineitem" +JOIN "part" AS "part" + ON "lineitem"."l_partkey" = "part"."p_partkey" +WHERE + "lineitem"."l_shipdate" < CAST('1995-10-01' AS DATE) + AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE); -------------------------------------- -- TPC-H 15 @@ -1165,14 +876,7 @@ SELECT "supplier"."s_address" AS "s_address", "supplier"."s_phone" AS "s_phone", "revenue"."total_revenue" AS "total_revenue" -FROM ( - SELECT - "supplier"."s_suppkey" AS "s_suppkey", - "supplier"."s_name" AS "s_name", - "supplier"."s_address" AS "s_address", - "supplier"."s_phone" AS "s_phone" - FROM "supplier" AS "supplier" -) AS "supplier" +FROM "supplier" AS "supplier" JOIN "revenue" ON "revenue"."total_revenue" = ( SELECT @@ -1221,12 +925,7 @@ SELECT "part"."p_type" AS "p_type", "part"."p_size" AS "p_size", COUNT(DISTINCT "partsupp"."ps_suppkey") AS "supplier_cnt" -FROM ( - SELECT - "partsupp"."ps_partkey" AS "ps_partkey", - "partsupp"."ps_suppkey" AS "ps_suppkey" - FROM "partsupp" AS "partsupp" -) AS "partsupp" +FROM "partsupp" AS "partsupp" LEFT JOIN ( SELECT "supplier"."s_suppkey" AS "s_suppkey" @@ -1237,21 +936,13 @@ LEFT JOIN ( "supplier"."s_suppkey" ) AS "_u_0" ON "partsupp"."ps_suppkey" = "_u_0"."s_suppkey" -JOIN ( - SELECT - "part"."p_partkey" AS "p_partkey", - "part"."p_brand" AS "p_brand", - "part"."p_type" AS "p_type", - "part"."p_size" AS "p_size" - FROM "part" AS "part" - WHERE - "part"."p_brand" <> 'Brand#45' - AND "part"."p_size" IN (49, 14, 23, 45, 19, 3, 36, 9) - AND NOT "part"."p_type" LIKE 'MEDIUM POLISHED%' -) AS "part" +JOIN "part" AS "part" ON "part"."p_partkey" = "partsupp"."ps_partkey" WHERE "_u_0"."s_suppkey" IS NULL + AND "part"."p_brand" <> 'Brand#45' + AND "part"."p_size" IN (49, 14, 23, 45, 19, 3, 36, 9) + AND NOT "part"."p_type" LIKE 'MEDIUM POLISHED%' GROUP BY "part"."p_brand", "part"."p_type", @@ -1284,23 +975,8 @@ where ); SELECT SUM("lineitem"."l_extendedprice") / 7.0 AS "avg_yearly" -FROM ( - SELECT - "lineitem"."l_partkey" AS "l_partkey", - "lineitem"."l_quantity" AS "l_quantity", - "lineitem"."l_extendedprice" AS "l_extendedprice" - FROM "lineitem" AS "lineitem" -) AS "lineitem" -JOIN ( - SELECT - "part"."p_partkey" AS "p_partkey", - "part"."p_brand" AS "p_brand", - "part"."p_container" AS "p_container" - FROM "part" AS "part" - WHERE - "part"."p_brand" = 'Brand#23' - AND "part"."p_container" = 'MED BOX' -) AS "part" +FROM "lineitem" AS "lineitem" +JOIN "part" AS "part" ON "part"."p_partkey" = "lineitem"."l_partkey" LEFT JOIN ( SELECT @@ -1313,6 +989,8 @@ LEFT JOIN ( ON "_u_0"."_u_1" = "part"."p_partkey" WHERE "lineitem"."l_quantity" < "_u_0"."_col_0" + AND "part"."p_brand" = 'Brand#23' + AND "part"."p_container" = 'MED BOX' AND NOT "_u_0"."_u_1" IS NULL; -------------------------------------- @@ -1359,20 +1037,8 @@ SELECT "orders"."o_orderdate" AS "o_orderdate", "orders"."o_totalprice" AS "o_totalprice", SUM("lineitem"."l_quantity") AS "_col_5" -FROM ( - SELECT - "customer"."c_custkey" AS "c_custkey", - "customer"."c_name" AS "c_name" - FROM "customer" AS "customer" -) AS "customer" -JOIN ( - SELECT - "orders"."o_orderkey" AS "o_orderkey", - "orders"."o_custkey" AS "o_custkey", - "orders"."o_totalprice" AS "o_totalprice", - "orders"."o_orderdate" AS "o_orderdate" - FROM "orders" AS "orders" -) AS "orders" +FROM "customer" AS "customer" +JOIN "orders" AS "orders" ON "customer"."c_custkey" = "orders"."o_custkey" LEFT JOIN ( SELECT @@ -1385,12 +1051,7 @@ LEFT JOIN ( SUM("lineitem"."l_quantity") > 300 
) AS "_u_0" ON "orders"."o_orderkey" = "_u_0"."l_orderkey" -JOIN ( - SELECT - "lineitem"."l_orderkey" AS "l_orderkey", - "lineitem"."l_quantity" AS "l_quantity" - FROM "lineitem" AS "lineitem" -) AS "lineitem" +JOIN "lineitem" AS "lineitem" ON "orders"."o_orderkey" = "lineitem"."l_orderkey" WHERE NOT "_u_0"."l_orderkey" IS NULL @@ -1447,24 +1108,8 @@ SELECT SUM("lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" )) AS "revenue" -FROM ( - SELECT - "lineitem"."l_partkey" AS "l_partkey", - "lineitem"."l_quantity" AS "l_quantity", - "lineitem"."l_extendedprice" AS "l_extendedprice", - "lineitem"."l_discount" AS "l_discount", - "lineitem"."l_shipinstruct" AS "l_shipinstruct", - "lineitem"."l_shipmode" AS "l_shipmode" - FROM "lineitem" AS "lineitem" -) AS "lineitem" -JOIN ( - SELECT - "part"."p_partkey" AS "p_partkey", - "part"."p_brand" AS "p_brand", - "part"."p_size" AS "p_size", - "part"."p_container" AS "p_container" - FROM "part" AS "part" -) AS "part" +FROM "lineitem" AS "lineitem" +JOIN "part" AS "part" ON ( "part"."p_brand" = 'Brand#12' AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') @@ -1558,14 +1203,7 @@ order by SELECT "supplier"."s_name" AS "s_name", "supplier"."s_address" AS "s_address" -FROM ( - SELECT - "supplier"."s_suppkey" AS "s_suppkey", - "supplier"."s_name" AS "s_name", - "supplier"."s_address" AS "s_address", - "supplier"."s_nationkey" AS "s_nationkey" - FROM "supplier" AS "supplier" -) AS "supplier" +FROM "supplier" AS "supplier" LEFT JOIN ( SELECT "partsupp"."ps_suppkey" AS "ps_suppkey" @@ -1604,17 +1242,11 @@ LEFT JOIN ( "partsupp"."ps_suppkey" ) AS "_u_4" ON "supplier"."s_suppkey" = "_u_4"."ps_suppkey" -JOIN ( - SELECT - "nation"."n_nationkey" AS "n_nationkey", - "nation"."n_name" AS "n_name" - FROM "nation" AS "nation" - WHERE - "nation"."n_name" = 'CANADA' -) AS "nation" +JOIN "nation" AS "nation" ON "supplier"."s_nationkey" = "nation"."n_nationkey" WHERE - NOT "_u_4"."ps_suppkey" IS NULL + "nation"."n_name" = 'CANADA' + AND NOT "_u_4"."ps_suppkey" IS NULL ORDER BY "s_name"; @@ -1665,24 +1297,9 @@ limit SELECT "supplier"."s_name" AS "s_name", COUNT(*) AS "numwait" -FROM ( - SELECT - "supplier"."s_suppkey" AS "s_suppkey", - "supplier"."s_name" AS "s_name", - "supplier"."s_nationkey" AS "s_nationkey" - FROM "supplier" AS "supplier" -) AS "supplier" -JOIN ( - SELECT - "lineitem"."l_orderkey" AS "l_orderkey", - "lineitem"."l_suppkey" AS "l_suppkey", - "lineitem"."l_commitdate" AS "l_commitdate", - "lineitem"."l_receiptdate" AS "l_receiptdate" - FROM "lineitem" AS "lineitem" - WHERE - "lineitem"."l_receiptdate" > "lineitem"."l_commitdate" -) AS "l1" - ON "supplier"."s_suppkey" = "l1"."l_suppkey" +FROM "supplier" AS "supplier" +JOIN "lineitem" AS "lineitem" + ON "supplier"."s_suppkey" = "lineitem"."l_suppkey" LEFT JOIN ( SELECT "l2"."l_orderkey" AS "l_orderkey", @@ -1691,7 +1308,7 @@ LEFT JOIN ( GROUP BY "l2"."l_orderkey" ) AS "_u_0" - ON "_u_0"."l_orderkey" = "l1"."l_orderkey" + ON "_u_0"."l_orderkey" = "lineitem"."l_orderkey" LEFT JOIN ( SELECT "l3"."l_orderkey" AS "l_orderkey", @@ -1702,31 +1319,20 @@ LEFT JOIN ( GROUP BY "l3"."l_orderkey" ) AS "_u_2" - ON "_u_2"."l_orderkey" = "l1"."l_orderkey" -JOIN ( - SELECT - "orders"."o_orderkey" AS "o_orderkey", - "orders"."o_orderstatus" AS "o_orderstatus" - FROM "orders" AS "orders" - WHERE - "orders"."o_orderstatus" = 'F' -) AS "orders" - ON "orders"."o_orderkey" = "l1"."l_orderkey" -JOIN ( - SELECT - "nation"."n_nationkey" AS "n_nationkey", - "nation"."n_name" AS "n_name" - FROM "nation" AS 
"nation" - WHERE - "nation"."n_name" = 'SAUDI ARABIA' -) AS "nation" + ON "_u_2"."l_orderkey" = "lineitem"."l_orderkey" +JOIN "orders" AS "orders" + ON "orders"."o_orderkey" = "lineitem"."l_orderkey" +JOIN "nation" AS "nation" ON "supplier"."s_nationkey" = "nation"."n_nationkey" WHERE ( "_u_2"."l_orderkey" IS NULL - OR NOT ARRAY_ANY("_u_2"."_u_3", "_x" -> "_x" <> "l1"."l_suppkey") + OR NOT ARRAY_ANY("_u_2"."_u_3", "_x" -> "_x" <> "lineitem"."l_suppkey") ) - AND ARRAY_ANY("_u_0"."_u_1", "_x" -> "_x" <> "l1"."l_suppkey") + AND "lineitem"."l_receiptdate" > "lineitem"."l_commitdate" + AND "nation"."n_name" = 'SAUDI ARABIA' + AND "orders"."o_orderstatus" = 'F' + AND ARRAY_ANY("_u_0"."_u_1", "_x" -> "_x" <> "lineitem"."l_suppkey") AND NOT "_u_0"."l_orderkey" IS NULL GROUP BY "supplier"."s_name" @@ -1776,35 +1382,30 @@ group by order by cntrycode; SELECT - "custsale"."cntrycode" AS "cntrycode", + SUBSTRING("customer"."c_phone", 1, 2) AS "cntrycode", COUNT(*) AS "numcust", - SUM("custsale"."c_acctbal") AS "totacctbal" -FROM ( + SUM("customer"."c_acctbal") AS "totacctbal" +FROM "customer" AS "customer" +LEFT JOIN ( SELECT - SUBSTRING("customer"."c_phone", 1, 2) AS "cntrycode", - "customer"."c_acctbal" AS "c_acctbal" - FROM "customer" AS "customer" - LEFT JOIN ( + "orders"."o_custkey" AS "_u_1" + FROM "orders" AS "orders" + GROUP BY + "orders"."o_custkey" +) AS "_u_0" + ON "_u_0"."_u_1" = "customer"."c_custkey" +WHERE + "_u_0"."_u_1" IS NULL + AND "customer"."c_acctbal" > ( SELECT - "orders"."o_custkey" AS "_u_1" - FROM "orders" AS "orders" - GROUP BY - "orders"."o_custkey" - ) AS "_u_0" - ON "_u_0"."_u_1" = "customer"."c_custkey" - WHERE - "_u_0"."_u_1" IS NULL - AND "customer"."c_acctbal" > ( - SELECT - AVG("customer"."c_acctbal") AS "_col_0" - FROM "customer" AS "customer" - WHERE - "customer"."c_acctbal" > 0.00 - AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17') - ) - AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17') -) AS "custsale" + AVG("customer"."c_acctbal") AS "_col_0" + FROM "customer" AS "customer" + WHERE + "customer"."c_acctbal" > 0.00 + AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17') + ) + AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17') GROUP BY - "custsale"."cntrycode" + SUBSTRING("customer"."c_phone", 1, 2) ORDER BY "cntrycode"; diff --git a/tests/helpers.py b/tests/helpers.py index d4edb14..ad50483 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -5,9 +5,7 @@ FIXTURES_DIR = os.path.join(FILE_DIR, "fixtures") def _filter_comments(s): - return "\n".join( - [line for line in s.splitlines() if line and not line.startswith("--")] - ) + return "\n".join([line for line in s.splitlines() if line and not line.startswith("--")]) def _extract_meta(sql): @@ -23,9 +21,7 @@ def _extract_meta(sql): def assert_logger_contains(message, logger, level="error"): - output = "\n".join( - str(args[0][0]) for args in getattr(logger, level).call_args_list - ) + output = "\n".join(str(args[0][0]) for args in getattr(logger, level).call_args_list) assert message in output diff --git a/tests/test_build.py b/tests/test_build.py index a4cffde..18c0e47 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -46,10 +46,7 @@ class TestBuild(unittest.TestCase): "SELECT x FROM tbl WHERE FALSE", ), ( - lambda: select("x") - .from_("tbl") - .where("x > 0") - .where("x < 9", append=False), + lambda: select("x").from_("tbl").where("x > 0").where("x < 9", 
append=False), "SELECT x FROM tbl WHERE x < 9", ), ( @@ -61,10 +58,7 @@ class TestBuild(unittest.TestCase): "SELECT x, y FROM tbl GROUP BY x, y", ), ( - lambda: select("x", "y", "z", "a") - .from_("tbl") - .group_by("x, y", "z") - .group_by("a"), + lambda: select("x", "y", "z", "a").from_("tbl").group_by("x, y", "z").group_by("a"), "SELECT x, y, z, a FROM tbl GROUP BY x, y, z, a", ), ( @@ -85,9 +79,7 @@ class TestBuild(unittest.TestCase): "SELECT x FROM tbl JOIN tbl2 ON tbl.y = tbl2.y", ), ( - lambda: select("x") - .from_("tbl") - .join("tbl2", on=["tbl.y = tbl2.y", "a = b"]), + lambda: select("x").from_("tbl").join("tbl2", on=["tbl.y = tbl2.y", "a = b"]), "SELECT x FROM tbl JOIN tbl2 ON tbl.y = tbl2.y AND a = b", ), ( @@ -95,21 +87,15 @@ class TestBuild(unittest.TestCase): "SELECT x FROM tbl LEFT OUTER JOIN tbl2", ), ( - lambda: select("x") - .from_("tbl") - .join(exp.Table(this="tbl2"), join_type="left outer"), + lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer"), "SELECT x FROM tbl LEFT OUTER JOIN tbl2", ), ( - lambda: select("x") - .from_("tbl") - .join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"), + lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"), "SELECT x FROM tbl LEFT OUTER JOIN tbl2 AS foo", ), ( - lambda: select("x") - .from_("tbl") - .join(select("y").from_("tbl2"), join_type="left outer"), + lambda: select("x").from_("tbl").join(select("y").from_("tbl2"), join_type="left outer"), "SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2)", ), ( @@ -132,9 +118,7 @@ class TestBuild(unittest.TestCase): "SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2) AS aliased", ), ( - lambda: select("x") - .from_("tbl") - .join(parse_one("left join x", into=exp.Join), on="a=b"), + lambda: select("x").from_("tbl").join(parse_one("left join x", into=exp.Join), on="a=b"), "SELECT x FROM tbl LEFT JOIN x ON a = b", ), ( @@ -142,9 +126,7 @@ class TestBuild(unittest.TestCase): "SELECT x FROM tbl LEFT JOIN x ON a = b", ), ( - lambda: select("x") - .from_("tbl") - .join("select b from tbl2", on="a=b", join_type="left"), + lambda: select("x").from_("tbl").join("select b from tbl2", on="a=b", join_type="left"), "SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) ON a = b", ), ( @@ -159,10 +141,7 @@ class TestBuild(unittest.TestCase): "SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) AS aliased ON a = b", ), ( - lambda: select("x", "COUNT(y)") - .from_("tbl") - .group_by("x") - .having("COUNT(y) > 0"), + lambda: select("x", "COUNT(y)").from_("tbl").group_by("x").having("COUNT(y) > 0"), "SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 0", ), ( @@ -190,24 +169,15 @@ class TestBuild(unittest.TestCase): "SELECT x FROM tbl SORT BY x, y DESC", ), ( - lambda: select("x", "y", "z", "a") - .from_("tbl") - .order_by("x, y", "z") - .order_by("a"), + lambda: select("x", "y", "z", "a").from_("tbl").order_by("x, y", "z").order_by("a"), "SELECT x, y, z, a FROM tbl ORDER BY x, y, z, a", ), ( - lambda: select("x", "y", "z", "a") - .from_("tbl") - .cluster_by("x, y", "z") - .cluster_by("a"), + lambda: select("x", "y", "z", "a").from_("tbl").cluster_by("x, y", "z").cluster_by("a"), "SELECT x, y, z, a FROM tbl CLUSTER BY x, y, z, a", ), ( - lambda: select("x", "y", "z", "a") - .from_("tbl") - .sort_by("x, y", "z") - .sort_by("a"), + lambda: select("x", "y", "z", "a").from_("tbl").sort_by("x, y", "z").sort_by("a"), "SELECT x, y, z, a FROM tbl SORT BY x, y, z, a", ), (lambda: 
select("x").from_("tbl").limit(10), "SELECT x FROM tbl LIMIT 10"), @@ -220,21 +190,15 @@ class TestBuild(unittest.TestCase): "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", ), ( - lambda: select("x") - .from_("tbl") - .with_("tbl", as_="SELECT x FROM tbl2", recursive=True), + lambda: select("x").from_("tbl").with_("tbl", as_="SELECT x FROM tbl2", recursive=True), "WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", ), ( - lambda: select("x") - .from_("tbl") - .with_("tbl", as_=select("x").from_("tbl2")), + lambda: select("x").from_("tbl").with_("tbl", as_=select("x").from_("tbl2")), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", ), ( - lambda: select("x") - .from_("tbl") - .with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")), + lambda: select("x").from_("tbl").with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")), "WITH tbl(x, y) AS (SELECT x, y FROM tbl2) SELECT x FROM tbl", ), ( @@ -245,72 +209,43 @@ class TestBuild(unittest.TestCase): "WITH tbl AS (SELECT x FROM tbl2), tbl2 AS (SELECT x FROM tbl3) SELECT x FROM tbl", ), ( - lambda: select("x") - .from_("tbl") - .with_("tbl", as_=select("x", "y").from_("tbl2")) - .select("y"), + lambda: select("x").from_("tbl").with_("tbl", as_=select("x", "y").from_("tbl2")).select("y"), "WITH tbl AS (SELECT x, y FROM tbl2) SELECT x, y FROM tbl", ), ( - lambda: select("x") - .with_("tbl", as_=select("x").from_("tbl2")) - .from_("tbl"), + lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl"), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", ), ( - lambda: select("x") - .with_("tbl", as_=select("x").from_("tbl2")) - .from_("tbl") - .group_by("x"), + lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").group_by("x"), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl GROUP BY x", ), ( - lambda: select("x") - .with_("tbl", as_=select("x").from_("tbl2")) - .from_("tbl") - .order_by("x"), + lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").order_by("x"), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl ORDER BY x", ), ( - lambda: select("x") - .with_("tbl", as_=select("x").from_("tbl2")) - .from_("tbl") - .limit(10), + lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").limit(10), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl LIMIT 10", ), ( - lambda: select("x") - .with_("tbl", as_=select("x").from_("tbl2")) - .from_("tbl") - .offset(10), + lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").offset(10), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl OFFSET 10", ), ( - lambda: select("x") - .with_("tbl", as_=select("x").from_("tbl2")) - .from_("tbl") - .join("tbl3"), + lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").join("tbl3"), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl JOIN tbl3", ), ( - lambda: select("x") - .with_("tbl", as_=select("x").from_("tbl2")) - .from_("tbl") - .distinct(), + lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").distinct(), "WITH tbl AS (SELECT x FROM tbl2) SELECT DISTINCT x FROM tbl", ), ( - lambda: select("x") - .with_("tbl", as_=select("x").from_("tbl2")) - .from_("tbl") - .where("x > 10"), + lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").where("x > 10"), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl WHERE x > 10", ), ( - lambda: select("x") - .with_("tbl", as_=select("x").from_("tbl2")) - .from_("tbl") - .having("x > 20"), + lambda: select("x").with_("tbl", 
as_=select("x").from_("tbl2")).from_("tbl").having("x > 20"), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl HAVING x > 20", ), (lambda: select("x").from_("tbl").subquery(), "(SELECT x FROM tbl)"), @@ -324,9 +259,7 @@ class TestBuild(unittest.TestCase): ), (lambda: from_("tbl").select("x"), "SELECT x FROM tbl"), ( - lambda: parse_one("SELECT a FROM tbl") - .assert_is(exp.Select) - .select("b"), + lambda: parse_one("SELECT a FROM tbl").assert_is(exp.Select).select("b"), "SELECT a, b FROM tbl", ), ( @@ -368,15 +301,11 @@ class TestBuild(unittest.TestCase): "SELECT * FROM x WHERE y = 1 AND z = 1", ), ( - lambda: exp.subquery("select x from tbl", "foo") - .select("x") - .where("x > 0"), + lambda: exp.subquery("select x from tbl", "foo").select("x").where("x > 0"), "SELECT x FROM (SELECT x FROM tbl) AS foo WHERE x > 0", ), ( - lambda: exp.subquery( - "select x from tbl UNION select x from bar", "unioned" - ).select("x"), + lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select("x"), "SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned", ), ]: diff --git a/tests/test_executor.py b/tests/test_executor.py index 9afa225..c5841d3 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -27,10 +27,7 @@ class TestExecutor(unittest.TestCase): ) cls.cache = {} - cls.sqls = [ - (sql, expected) - for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql") - ] + cls.sqls = [(sql, expected) for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")] @classmethod def tearDownClass(cls): @@ -50,18 +47,17 @@ class TestExecutor(unittest.TestCase): self.assertEqual(Python().generate(parse_one("'x '''")), r"'x \''") def test_optimized_tpch(self): - for sql, optimized in self.sqls[0:20]: - a = self.cached_execute(sql) - b = self.conn.execute(optimized).fetchdf() - self.rename_anonymous(b, a) - assert_frame_equal(a, b) + for i, (sql, optimized) in enumerate(self.sqls[:20], start=1): + with self.subTest(f"{i}, {sql}"): + a = self.cached_execute(sql) + b = self.conn.execute(optimized).fetchdf() + self.rename_anonymous(b, a) + assert_frame_equal(a, b) def test_execute_tpch(self): def to_csv(expression): if isinstance(expression, exp.Table): - return parse_one( - f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}" - ) + return parse_one(f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}") return expression for sql, _ in self.sqls[0:3]: diff --git a/tests/test_expressions.py b/tests/test_expressions.py index eaef022..716e457 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -26,9 +26,7 @@ class TestExpressions(unittest.TestCase): parse_one("ROW() OVER(Partition by y)"), parse_one("ROW() OVER (partition BY y)"), ) - self.assertEqual( - parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)") - ) + self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)")) def test_find(self): expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y") @@ -87,9 +85,7 @@ class TestExpressions(unittest.TestCase): self.assertIsNone(column.find_ancestor(exp.Join)) def test_alias_or_name(self): - expression = parse_one( - "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz" - ) + expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz") self.assertEqual( [e.alias_or_name for e in expression.expressions], ["a", "B", "e", "*", "zz", "z"], @@ -118,9 
+114,7 @@ class TestExpressions(unittest.TestCase): ) def test_named_selects(self): - expression = parse_one( - "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz" - ) + expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz") self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"]) expression = parse_one( @@ -196,15 +190,9 @@ class TestExpressions(unittest.TestCase): def test_sql(self): self.assertEqual(parse_one("x + y * 2").sql(), "x + y * 2") - self.assertEqual( - parse_one('select "x"').sql(dialect="hive", pretty=True), "SELECT\n `x`" - ) - self.assertEqual( - parse_one("X + y").sql(identify=True, normalize=True), '"x" + "y"' - ) - self.assertEqual( - parse_one("SUM(X)").sql(identify=True, normalize=True), 'SUM("x")' - ) + self.assertEqual(parse_one('select "x"').sql(dialect="hive", pretty=True), "SELECT\n `x`") + self.assertEqual(parse_one("X + y").sql(identify=True, normalize=True), '"x" + "y"') + self.assertEqual(parse_one("SUM(X)").sql(identify=True, normalize=True), 'SUM("x")') def test_transform_with_arguments(self): expression = parse_one("a") @@ -229,15 +217,11 @@ class TestExpressions(unittest.TestCase): return node actual_expression_1 = expression.transform(fun) - self.assertEqual( - actual_expression_1.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)" - ) + self.assertEqual(actual_expression_1.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)") self.assertIsNot(actual_expression_1, expression) actual_expression_2 = expression.transform(fun, copy=False) - self.assertEqual( - actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)" - ) + self.assertEqual(actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)") self.assertIs(actual_expression_2, expression) with self.assertRaises(ValueError): @@ -274,12 +258,8 @@ class TestExpressions(unittest.TestCase): expression = parse_one("SELECT * FROM (SELECT * FROM x)") self.assertEqual(len(list(expression.walk())), 9) self.assertEqual(len(list(expression.walk(bfs=False))), 9) - self.assertTrue( - all(isinstance(e, exp.Expression) for e, _, _ in expression.walk()) - ) - self.assertTrue( - all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False)) - ) + self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk())) + self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False))) def test_functions(self): self.assertIsInstance(parse_one("ABS(a)"), exp.Abs) @@ -303,9 +283,7 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("IF(a, b, c)"), exp.If) self.assertIsInstance(parse_one("INITCAP(a)"), exp.Initcap) self.assertIsInstance(parse_one("JSON_EXTRACT(a, '$.name')"), exp.JSONExtract) - self.assertIsInstance( - parse_one("JSON_EXTRACT_SCALAR(a, '$.name')"), exp.JSONExtractScalar - ) + self.assertIsInstance(parse_one("JSON_EXTRACT_SCALAR(a, '$.name')"), exp.JSONExtractScalar) self.assertIsInstance(parse_one("LEAST(a, b)"), exp.Least) self.assertIsInstance(parse_one("LN(a)"), exp.Ln) self.assertIsInstance(parse_one("LOG10(a)"), exp.Log10) @@ -334,6 +312,7 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("TIME_STR_TO_DATE(a)"), exp.TimeStrToDate) self.assertIsInstance(parse_one("TIME_STR_TO_TIME(a)"), exp.TimeStrToTime) self.assertIsInstance(parse_one("TIME_STR_TO_UNIX(a)"), exp.TimeStrToUnix) + self.assertIsInstance(parse_one("TRIM(LEADING 'b' FROM 'bla')"), exp.Trim) self.assertIsInstance(parse_one("TS_OR_DS_ADD(a, 1, 
'day')"), exp.TsOrDsAdd) self.assertIsInstance(parse_one("TS_OR_DS_TO_DATE(a)"), exp.TsOrDsToDate) self.assertIsInstance(parse_one("TS_OR_DS_TO_DATE_STR(a)"), exp.Substring) @@ -404,12 +383,8 @@ class TestExpressions(unittest.TestCase): self.assertFalse(exp.to_identifier("x").quoted) def test_function_normalizer(self): - self.assertEqual( - parse_one("HELLO()").sql(normalize_functions="lower"), "hello()" - ) - self.assertEqual( - parse_one("hello()").sql(normalize_functions="upper"), "HELLO()" - ) + self.assertEqual(parse_one("HELLO()").sql(normalize_functions="lower"), "hello()") + self.assertEqual(parse_one("hello()").sql(normalize_functions="upper"), "HELLO()") self.assertEqual(parse_one("heLLO()").sql(normalize_functions=None), "heLLO()") self.assertEqual(parse_one("SUM(x)").sql(normalize_functions="lower"), "sum(x)") self.assertEqual(parse_one("sum(x)").sql(normalize_functions="upper"), "SUM(x)") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 40540b3..102e141 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -31,9 +31,7 @@ class TestOptimizer(unittest.TestCase): dialect = meta.get("dialect") with self.subTest(sql): self.assertEqual( - func(parse_one(sql, read=dialect), **kwargs).sql( - pretty=pretty, dialect=dialect - ), + func(parse_one(sql, read=dialect), **kwargs).sql(pretty=pretty, dialect=dialect), expected, ) @@ -86,9 +84,7 @@ class TestOptimizer(unittest.TestCase): for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"): with self.subTest(sql): with self.assertRaises(OptimizeError): - optimizer.qualify_columns.qualify_columns( - parse_one(sql), schema=self.schema - ) + optimizer.qualify_columns.qualify_columns(parse_one(sql), schema=self.schema) def test_quote_identities(self): self.check_file("quote_identities", optimizer.quote_identities.quote_identities) @@ -100,9 +96,7 @@ class TestOptimizer(unittest.TestCase): expression = optimizer.pushdown_projections.pushdown_projections(expression) return expression - self.check_file( - "pushdown_projections", pushdown_projections, schema=self.schema - ) + self.check_file("pushdown_projections", pushdown_projections, schema=self.schema) def test_simplify(self): self.check_file("simplify", optimizer.simplify.simplify) @@ -115,9 +109,7 @@ class TestOptimizer(unittest.TestCase): ) def test_pushdown_predicates(self): - self.check_file( - "pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates - ) + self.check_file("pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates) def test_expand_multi_table_selects(self): self.check_file( @@ -138,10 +130,17 @@ class TestOptimizer(unittest.TestCase): pretty=True, ) + def test_merge_derived_tables(self): + def optimize(expression, **kwargs): + expression = optimizer.qualify_tables.qualify_tables(expression) + expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) + expression = optimizer.merge_derived_tables.merge_derived_tables(expression) + return expression + + self.check_file("merge_derived_tables", optimize, schema=self.schema) + def test_tpch(self): - self.check_file( - "tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True - ) + self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True) def test_schema(self): schema = ensure_schema( @@ -262,9 +261,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(len(scopes), 5) self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x") 
self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") - self.assertEqual( - scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b" - ) + self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y") self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql()) diff --git a/tests/test_parser.py b/tests/test_parser.py index 779083d..1054103 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -16,28 +16,23 @@ class TestParser(unittest.TestCase): self.assertIsInstance(parse_one("array", into=exp.DataType), exp.DataType) def test_column(self): - columns = parse_one("select a, ARRAY[1] b, case when 1 then 1 end").find_all( - exp.Column - ) + columns = parse_one("select a, ARRAY[1] b, case when 1 then 1 end").find_all(exp.Column) assert len(list(columns)) == 1 self.assertIsNotNone(parse_one("date").find(exp.Column)) def test_table(self): - tables = [ - t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table) - ] + tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)] self.assertEqual(tables, ["a", "b.c", "d"]) def test_select(self): - self.assertIsNotNone( - parse_one("select * from (select 1) x order by x.y").args["order"] - ) - self.assertIsNotNone( - parse_one("select * from x where a = (select 1) order by x.y").args["order"] - ) + self.assertIsNotNone(parse_one("select 1 natural")) + self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"]) + self.assertIsNotNone(parse_one("select * from x where a = (select 1) order by x.y").args["order"]) + self.assertEqual(len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1) self.assertEqual( - len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1 + parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(), + """SELECT * FROM x, z LATERAL VIEW EXPLODE(y) CROSS JOIN y""", ) def test_command(self): @@ -72,12 +67,8 @@ class TestParser(unittest.TestCase): ) assert len(expressions) == 2 - assert ( - expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a" - ) - assert ( - expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b" - ) + assert expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a" + assert expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b" def test_expression(self): ignore = Parser(error_level=ErrorLevel.IGNORE) @@ -147,13 +138,9 @@ class TestParser(unittest.TestCase): def test_pretty_config_override(self): self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x") with patch("sqlglot.pretty", True): - self.assertEqual( - parse_one("SELECT col FROM x").sql(), "SELECT\n col\nFROM x" - ) + self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT\n col\nFROM x") - self.assertEqual( - parse_one("SELECT col FROM x").sql(pretty=True), "SELECT\n col\nFROM x" - ) + self.assertEqual(parse_one("SELECT col FROM x").sql(pretty=True), "SELECT\n col\nFROM x") @patch("sqlglot.parser.logger") def test_comment_error_n(self, logger): diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 28bcc7a..4bec2ac 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -42,6 +42,20 @@ class TestTranspile(unittest.TestCase): "SELECT * FROM x WHERE a = ANY (SELECT 1)", ) + def test_leading_comma(self): + self.validate( + "SELECT FOO, BAR, BAZ", + "SELECT\n FOO\n , 
BAR\n , BAZ", + leading_comma=True, + pretty=True, + ) + # without pretty, this should be a no-op + self.validate( + "SELECT FOO, BAR, BAZ", + "SELECT FOO, BAR, BAZ", + leading_comma=True, + ) + def test_space(self): self.validate("SELECT MIN(3)>MIN(2)", "SELECT MIN(3) > MIN(2)") self.validate("SELECT MIN(3)>=MIN(2)", "SELECT MIN(3) >= MIN(2)") @@ -108,6 +122,11 @@ class TestTranspile(unittest.TestCase): "extract(month from '2021-01-31'::timestamp without time zone)", "EXTRACT(month FROM CAST('2021-01-31' AS TIMESTAMP))", ) + self.validate("extract(week from current_date + 2)", "EXTRACT(week FROM CURRENT_DATE + 2)") + self.validate( + "EXTRACT(minute FROM datetime1 - datetime2)", + "EXTRACT(minute FROM datetime1 - datetime2)", + ) def test_if(self): self.validate( @@ -122,18 +141,14 @@ class TestTranspile(unittest.TestCase): "SELECT IF a > 1 THEN b ELSE c END", "SELECT CASE WHEN a > 1 THEN b ELSE c END", ) - self.validate( - "SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo" - ) + self.validate("SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo") def test_ignore_nulls(self): self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)") def test_time(self): self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)") - self.validate( - "TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)" - ) + self.validate("TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)") self.validate( "TIMESTAMP(9) WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ(9))", @@ -159,9 +174,7 @@ class TestTranspile(unittest.TestCase): self.validate("DATE '2020-01-01'", "CAST('2020-01-01' AS DATE)") self.validate("'2020-01-01'::DATE", "CAST('2020-01-01' AS DATE)") self.validate("STR_TO_TIME('x', 'y')", "STRPTIME('x', 'y')", write="duckdb") - self.validate( - "STR_TO_UNIX('x', 'y')", "EPOCH(STRPTIME('x', 'y'))", write="duckdb" - ) + self.validate("STR_TO_UNIX('x', 'y')", "EPOCH(STRPTIME('x', 'y'))", write="duckdb") self.validate("TIME_TO_STR(x, 'y')", "STRFTIME(x, 'y')", write="duckdb") self.validate("TIME_TO_UNIX(x)", "EPOCH(x)", write="duckdb") self.validate( @@ -209,12 +222,8 @@ class TestTranspile(unittest.TestCase): self.validate("TIME_STR_TO_DATE(x)", "TIME_STR_TO_DATE(x)", write=None) self.validate("TIME_STR_TO_DATE(x)", "TO_DATE(x)", write="hive") - self.validate( - "UNIX_TO_STR(x, 'yyyy-MM-dd HH:mm:ss')", "FROM_UNIXTIME(x)", write="hive" - ) - self.validate( - "STR_TO_UNIX(x, 'yyyy-MM-dd HH:mm:ss')", "UNIX_TIMESTAMP(x)", write="hive" - ) + self.validate("UNIX_TO_STR(x, 'yyyy-MM-dd HH:mm:ss')", "FROM_UNIXTIME(x)", write="hive") + self.validate("STR_TO_UNIX(x, 'yyyy-MM-dd HH:mm:ss')", "UNIX_TIMESTAMP(x)", write="hive") self.validate("IF(x > 1, x + 1)", "IF(x > 1, x + 1)", write="presto") self.validate("IF(x > 1, 1 + 1)", "IF(x > 1, 1 + 1)", write="hive") self.validate("IF(x > 1, 1, 0)", "IF(x > 1, 1, 0)", write="hive") @@ -232,9 +241,7 @@ class TestTranspile(unittest.TestCase): ) self.validate("STR_TO_TIME('x', 'y')", "DATE_PARSE('x', 'y')", write="presto") - self.validate( - "STR_TO_UNIX('x', 'y')", "TO_UNIXTIME(DATE_PARSE('x', 'y'))", write="presto" - ) + self.validate("STR_TO_UNIX('x', 'y')", "TO_UNIXTIME(DATE_PARSE('x', 'y'))", write="presto") self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="presto") self.validate("TIME_TO_UNIX(x)", "TO_UNIXTIME(x)", write="presto") self.validate( @@ -245,9 +252,7 @@ class TestTranspile(unittest.TestCase): 
self.validate("UNIX_TO_TIME(123)", "FROM_UNIXTIME(123)", write="presto") self.validate("STR_TO_TIME('x', 'y')", "TO_TIMESTAMP('x', 'y')", write="spark") - self.validate( - "STR_TO_UNIX('x', 'y')", "UNIX_TIMESTAMP('x', 'y')", write="spark" - ) + self.validate("STR_TO_UNIX('x', 'y')", "UNIX_TIMESTAMP('x', 'y')", write="spark") self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="spark") self.validate( @@ -283,9 +288,7 @@ class TestTranspile(unittest.TestCase): def test_partial(self): for sql in load_sql_fixtures("partial.sql"): with self.subTest(sql): - self.assertEqual( - transpile(sql, error_level=ErrorLevel.IGNORE)[0], sql.strip() - ) + self.assertEqual(transpile(sql, error_level=ErrorLevel.IGNORE)[0], sql.strip()) def test_pretty(self): for _, sql, pretty in load_sql_fixture_pairs("pretty.sql"): -- cgit v1.2.3