From 90150543f9314be683d22a16339effd774192f6d Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 22 Sep 2022 06:31:28 +0200 Subject: Merging upstream version 6.1.1. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 2 +- sqlglot/__main__.py | 7 +- sqlglot/dialects/__init__.py | 1 + sqlglot/dialects/bigquery.py | 9 +- sqlglot/dialects/dialect.py | 31 ++- sqlglot/dialects/duckdb.py | 5 +- sqlglot/dialects/hive.py | 15 +- sqlglot/dialects/mysql.py | 29 +++ sqlglot/dialects/oracle.py | 8 + sqlglot/dialects/postgres.py | 116 ++++++++- sqlglot/dialects/presto.py | 6 +- sqlglot/dialects/redshift.py | 34 +++ sqlglot/dialects/snowflake.py | 4 +- sqlglot/dialects/spark.py | 15 +- sqlglot/dialects/sqlite.py | 1 + sqlglot/dialects/trino.py | 3 + sqlglot/diff.py | 35 +-- sqlglot/executor/__init__.py | 10 +- sqlglot/executor/context.py | 4 +- sqlglot/executor/python.py | 14 +- sqlglot/executor/table.py | 5 +- sqlglot/expressions.py | 169 ++++++++---- sqlglot/generator.py | 167 ++++++------ sqlglot/optimizer/__init__.py | 2 +- sqlglot/optimizer/isolate_table_selects.py | 4 +- sqlglot/optimizer/merge_derived_tables.py | 232 +++++++++++++++++ sqlglot/optimizer/normalize.py | 22 +- sqlglot/optimizer/optimize_joins.py | 6 +- sqlglot/optimizer/optimizer.py | 39 ++- sqlglot/optimizer/pushdown_predicates.py | 20 +- sqlglot/optimizer/qualify_columns.py | 36 +-- sqlglot/optimizer/qualify_tables.py | 4 +- sqlglot/optimizer/schema.py | 4 +- sqlglot/optimizer/scope.py | 58 ++--- sqlglot/optimizer/simplify.py | 8 +- sqlglot/optimizer/unnest_subqueries.py | 22 +- sqlglot/parser.py | 404 ++++++++++++++--------------- sqlglot/planner.py | 21 +- sqlglot/tokens.py | 184 ++++++++----- sqlglot/transforms.py | 4 +- 40 files changed, 1082 insertions(+), 678 deletions(-) create mode 100644 sqlglot/dialects/redshift.py create mode 100644 sqlglot/optimizer/merge_derived_tables.py diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 0007e34..3fa40ce 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -20,7 +20,7 @@ from sqlglot.generator import Generator from sqlglot.parser import Parser from sqlglot.tokens import Tokenizer, TokenType -__version__ = "6.0.4" +__version__ = "6.1.1" pretty = False diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py index 25200c4..4161259 100644 --- a/sqlglot/__main__.py +++ b/sqlglot/__main__.py @@ -49,12 +49,7 @@ args = parser.parse_args() error_level = sqlglot.ErrorLevel[args.error_level.upper()] if args.parse: - sqls = [ - repr(expression) - for expression in sqlglot.parse( - args.sql, read=args.read, error_level=error_level - ) - ] + sqls = [repr(expression) for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)] else: sqls = sqlglot.transpile( args.sql, diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py index 5aa7d77..f7d03ad 100644 --- a/sqlglot/dialects/__init__.py +++ b/sqlglot/dialects/__init__.py @@ -7,6 +7,7 @@ from sqlglot.dialects.mysql import MySQL from sqlglot.dialects.oracle import Oracle from sqlglot.dialects.postgres import Postgres from sqlglot.dialects.presto import Presto +from sqlglot.dialects.redshift import Redshift from sqlglot.dialects.snowflake import Snowflake from sqlglot.dialects.spark import Spark from sqlglot.dialects.sqlite import SQLite diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index f4e87c3..1f1f90a 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -44,6 +44,7 @@ class BigQuery(Dialect): ] IDENTIFIERS
= ["`"] ESCAPE = "\\" + HEX_STRINGS = [("0x", ""), ("0X", "")] KEYWORDS = { **Tokenizer.KEYWORDS, @@ -120,9 +121,5 @@ class BigQuery(Dialect): def intersect_op(self, expression): if not expression.args.get("distinct", False): - self.unsupported( - "INTERSECT without DISTINCT is not supported in BigQuery" - ) - return ( - f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" - ) + self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery") + return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 8045f7a..f338c81 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -20,6 +20,7 @@ class Dialects(str, Enum): ORACLE = "oracle" POSTGRES = "postgres" PRESTO = "presto" + REDSHIFT = "redshift" SNOWFLAKE = "snowflake" SPARK = "spark" SQLITE = "sqlite" @@ -53,12 +54,19 @@ class _Dialect(type): klass.generator_class = getattr(klass, "Generator", Generator) klass.tokenizer = klass.tokenizer_class() - klass.quote_start, klass.quote_end = list(klass.tokenizer_class.QUOTES.items())[ - 0 - ] - klass.identifier_start, klass.identifier_end = list( - klass.tokenizer_class.IDENTIFIERS.items() - )[0] + klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0] + klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0] + + if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS: + bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0] + klass.generator_class.TRANSFORMS[ + exp.BitString + ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}" + if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS: + hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0] + klass.generator_class.TRANSFORMS[ + exp.HexString + ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}" return klass @@ -122,9 +130,7 @@ class Dialect(metaclass=_Dialect): return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql) def parse_into(self, expression_type, sql, **opts): - return self.parser(**opts).parse_into( - expression_type, self.tokenizer.tokenize(sql), sql - ) + return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql) def generate(self, expression, **opts): return self.generator(**opts).generate(expression) @@ -164,9 +170,7 @@ class Dialect(metaclass=_Dialect): def rename_func(name): - return ( - lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})" - ) + return lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})" def approx_count_distinct_sql(self, expression): @@ -260,8 +264,7 @@ def format_time_lambda(exp_class, dialect, default=None): return exp_class( this=list_get(args, 0), format=Dialect[dialect].format_time( - list_get(args, 1) - or (Dialect[dialect].time_format if default is True else default) + list_get(args, 1) or (Dialect[dialect].time_format if default is True else default) ), ) diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index d83a620..ff3a8b1 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -63,10 +63,7 @@ def _sort_array_reverse(args): def _struct_pack_sql(self, expression): - args = [ - self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) - for e in expression.expressions - ] + 
args = [self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) for e in expression.expressions] return f"STRUCT_PACK({', '.join(args)})" diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index e3f3f39..59aa8fa 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -109,9 +109,7 @@ def _unnest_to_explode_sql(self, expression): alias=exp.TableAlias(this=alias.this, columns=[column]), ) ) - for expression, column in zip( - unnest.expressions, alias.columns if alias else [] - ) + for expression, column in zip(unnest.expressions, alias.columns if alias else []) ) return self.join_sql(expression) @@ -206,14 +204,11 @@ class Hive(Dialect): substr=list_get(args, 0), position=list_get(args, 2), ), - "LOG": ( - lambda args: exp.Log.from_arg_list(args) - if len(args) > 1 - else exp.Ln.from_arg_list(args) - ), + "LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)), "MAP": _parse_map, "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), "PERCENTILE": exp.Quantile.from_arg_list, + "PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list, "COLLECT_SET": exp.SetAgg.from_arg_list, "SIZE": exp.ArraySize.from_arg_list, "SPLIT": exp.RegexpSplit.from_arg_list, @@ -262,6 +257,7 @@ class Hive(Dialect): HiveMap: _map_sql, exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e.args['value'])}", exp.Quantile: rename_func("PERCENTILE"), + exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"), exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"), exp.RegexpSplit: rename_func("SPLIT"), exp.SafeDivide: no_safe_divide_sql, @@ -296,8 +292,7 @@ class Hive(Dialect): def datatype_sql(self, expression): if ( - expression.this - in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR) + expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR) and not expression.expressions ): expression = exp.DataType.build("text") diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 93800a6..87a2c41 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -49,6 +49,21 @@ def _str_to_date_sql(self, expression): return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})" +def _trim_sql(self, expression): + target = self.sql(expression, "this") + trim_type = self.sql(expression, "position") + remove_chars = self.sql(expression, "expression") + + # Use TRIM/LTRIM/RTRIM syntax if the expression isn't mysql-specific + if not remove_chars: + return self.trim_sql(expression) + + trim_type = f"{trim_type} " if trim_type else "" + remove_chars = f"{remove_chars} " if remove_chars else "" + from_part = "FROM " if trim_type or remove_chars else "" + return f"TRIM({trim_type}{remove_chars}{from_part}{target})" + + def _date_add(expression_class): def func(args): interval = list_get(args, 1) @@ -88,9 +103,12 @@ class MySQL(Dialect): QUOTES = ["'", '"'] COMMENTS = ["--", "#", ("/*", "*/")] IDENTIFIERS = ["`"] + BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")] + HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")] KEYWORDS = { **Tokenizer.KEYWORDS, + "SEPARATOR": TokenType.SEPARATOR, "_ARMSCII8": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER, "_BIG5": TokenType.INTRODUCER, @@ -145,6 +163,15 @@ class MySQL(Dialect): "STR_TO_DATE": _str_to_date, } + FUNCTION_PARSERS = { + **Parser.FUNCTION_PARSERS, + "GROUP_CONCAT": lambda self: self.expression( + exp.GroupConcat, + this=self._parse_lambda(), + separator=self._match(TokenType.SEPARATOR) and self._parse_field(), + 
), + } + class Generator(Generator): NULL_ORDERING_SUPPORTED = False @@ -158,6 +185,8 @@ class MySQL(Dialect): exp.DateAdd: _date_add_sql("ADD"), exp.DateSub: _date_add_sql("SUB"), exp.DateTrunc: _date_trunc_sql, + exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""", exp.StrToDate: _str_to_date_sql, exp.StrToTime: _str_to_date_sql, + exp.Trim: _trim_sql, } diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 9c8b6f2..91e30b2 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -51,6 +51,14 @@ class Oracle(Dialect): sep="", ) + def alias_sql(self, expression): + if isinstance(expression.this, exp.Table): + to_sql = self.sql(expression, "alias") + # oracle does not allow "AS" between table and alias + to_sql = f" {to_sql}" if to_sql else "" + return f"{self.sql(expression, 'this')}{to_sql}" + return super().alias_sql(expression) + def offset_sql(self, expression): return f"{super().offset_sql(expression)} ROWS" diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 61dff86..c796839 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import ( from sqlglot.generator import Generator from sqlglot.parser import Parser from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.transforms import delegate, preprocess def _date_add_sql(kind): @@ -32,11 +33,96 @@ def _date_add_sql(kind): return func +def _lateral_sql(self, expression): + this = self.sql(expression, "this") + if isinstance(expression.this, exp.Subquery): + return f"LATERAL{self.sep()}{this}" + alias = expression.args["alias"] + table = alias.name + table = f" {table}" if table else table + columns = self.expressions(alias, key="columns", flat=True) + columns = f" AS {columns}" if columns else "" + return f"LATERAL{self.sep()}{this}{table}{columns}" + + +def _substring_sql(self, expression): + this = self.sql(expression, "this") + start = self.sql(expression, "start") + length = self.sql(expression, "length") + + from_part = f" FROM {start}" if start else "" + for_part = f" FOR {length}" if length else "" + + return f"SUBSTRING({this}{from_part}{for_part})" + + +def _trim_sql(self, expression): + target = self.sql(expression, "this") + trim_type = self.sql(expression, "position") + remove_chars = self.sql(expression, "expression") + collation = self.sql(expression, "collation") + + # Use TRIM/LTRIM/RTRIM syntax if the expression isn't postgres-specific + if not remove_chars and not collation: + return self.trim_sql(expression) + + trim_type = f"{trim_type} " if trim_type else "" + remove_chars = f"{remove_chars} " if remove_chars else "" + from_part = "FROM " if trim_type or remove_chars else "" + collation = f" COLLATE {collation}" if collation else "" + return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" + + +def _auto_increment_to_serial(expression): + auto = expression.find(exp.AutoIncrementColumnConstraint) + + if auto: + expression = expression.copy() + expression.args["constraints"].remove(auto.parent) + kind = expression.args["kind"] + + if kind.this == exp.DataType.Type.INT: + kind.replace(exp.DataType(this=exp.DataType.Type.SERIAL)) + elif kind.this == exp.DataType.Type.SMALLINT: + kind.replace(exp.DataType(this=exp.DataType.Type.SMALLSERIAL)) + elif kind.this == exp.DataType.Type.BIGINT: + kind.replace(exp.DataType(this=exp.DataType.Type.BIGSERIAL)) + + return expression + + +def 
_serial_to_generated(expression): + kind = expression.args["kind"] + + if kind.this == exp.DataType.Type.SERIAL: + data_type = exp.DataType(this=exp.DataType.Type.INT) + elif kind.this == exp.DataType.Type.SMALLSERIAL: + data_type = exp.DataType(this=exp.DataType.Type.SMALLINT) + elif kind.this == exp.DataType.Type.BIGSERIAL: + data_type = exp.DataType(this=exp.DataType.Type.BIGINT) + else: + data_type = None + + if data_type: + expression = expression.copy() + expression.args["kind"].replace(data_type) + constraints = expression.args["constraints"] + generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False)) + notnull = exp.ColumnConstraint(kind=exp.NotNullColumnConstraint()) + if notnull not in constraints: + constraints.insert(0, notnull) + if generated not in constraints: + constraints.insert(0, generated) + + return expression + + class Postgres(Dialect): null_ordering = "nulls_are_large" time_format = "'YYYY-MM-DD HH24:MI:SS'" time_mapping = { - "AM": "%p", # AM or PM + "AM": "%p", + "PM": "%p", "D": "%w", # 1-based day of week "DD": "%d", # day of month "DDD": "%j", # zero padded day of year @@ -65,14 +151,25 @@ class Postgres(Dialect): } class Tokenizer(Tokenizer): + BIT_STRINGS = [("b'", "'"), ("B'", "'")] + HEX_STRINGS = [("x'", "'"), ("X'", "'")] KEYWORDS = { **Tokenizer.KEYWORDS, - "SERIAL": TokenType.AUTO_INCREMENT, + "ALWAYS": TokenType.ALWAYS, + "BY DEFAULT": TokenType.BY_DEFAULT, + "IDENTITY": TokenType.IDENTITY, + "FOR": TokenType.FOR, + "GENERATED": TokenType.GENERATED, + "DOUBLE PRECISION": TokenType.DOUBLE, + "BIGSERIAL": TokenType.BIGSERIAL, + "SERIAL": TokenType.SERIAL, + "SMALLSERIAL": TokenType.SMALLSERIAL, "UUID": TokenType.UUID, } class Parser(Parser): STRICT_CAST = False + FUNCTIONS = { **Parser.FUNCTIONS, "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"), @@ -86,14 +183,18 @@ class Postgres(Dialect): exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", exp.DataType.Type.BINARY: "BYTEA", - } - - TOKEN_MAPPING = { - TokenType.AUTO_INCREMENT: "SERIAL", + exp.DataType.Type.DATETIME: "TIMESTAMP", } TRANSFORMS = { **Generator.TRANSFORMS, + exp.ColumnDef: preprocess( + [ + _auto_increment_to_serial, + _serial_to_generated, + ], + delegate("columndef_sql"), + ), exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}", @@ -102,8 +203,11 @@ class Postgres(Dialect): exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: _date_add_sql("+"), exp.DateSub: _date_add_sql("-"), + exp.Lateral: _lateral_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.Substring: _substring_sql, exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TableSample: no_tablesample_sql, + exp.Trim: _trim_sql, exp.TryCast: no_trycast_sql, } diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index ca913e4..7253f7e 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -96,9 +96,7 @@ def _ts_or_ds_to_date_sql(self, expression): time_format = self.format_time(expression) if time_format and time_format not in (Presto.time_format, Presto.date_format): return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" - return ( - f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)" - ) + return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 
10) AS DATE)" def _ts_or_ds_add_sql(self, expression): @@ -141,6 +139,7 @@ class Presto(Dialect): "FROM_UNIXTIME": exp.UnixToTime.from_arg_list, "STRPOS": exp.StrPosition.from_arg_list, "TO_UNIXTIME": exp.TimeToUnix.from_arg_list, + "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, } class Generator(Generator): @@ -193,6 +192,7 @@ class Presto(Dialect): exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY = {self.sql(e.args['value'])}", exp.Quantile: _quantile_sql, + exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, exp.SortArray: _no_sort_array, diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py new file mode 100644 index 0000000..e1f7b78 --- /dev/null +++ b/sqlglot/dialects/redshift.py @@ -0,0 +1,34 @@ +from sqlglot import exp +from sqlglot.dialects.postgres import Postgres +from sqlglot.tokens import TokenType + + +class Redshift(Postgres): + time_format = "'YYYY-MM-DD HH:MI:SS'" + time_mapping = { + **Postgres.time_mapping, + "MON": "%b", + "HH": "%H", + } + + class Tokenizer(Postgres.Tokenizer): + ESCAPE = "\\" + + KEYWORDS = { + **Postgres.Tokenizer.KEYWORDS, + "GEOMETRY": TokenType.GEOMETRY, + "GEOGRAPHY": TokenType.GEOGRAPHY, + "HLLSKETCH": TokenType.HLLSKETCH, + "SUPER": TokenType.SUPER, + "TIME": TokenType.TIMESTAMP, + "TIMETZ": TokenType.TIMESTAMPTZ, + "VARBYTE": TokenType.BINARY, + "SIMILAR TO": TokenType.SIMILAR_TO, + } + + class Generator(Postgres.Generator): + TYPE_MAPPING = { + **Postgres.Generator.TYPE_MAPPING, + exp.DataType.Type.BINARY: "VARBYTE", + exp.DataType.Type.INT: "INTEGER", + } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 148dfb5..8d6ee78 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -23,9 +23,7 @@ def _snowflake_to_timestamp(args): # case: [ , ] if second_arg.name not in ["0", "3", "9"]: - raise ValueError( - f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9" - ) + raise ValueError(f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9") if second_arg.name == "0": timescale = exp.UnixToTime.SECONDS diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 89c7ed5..a331191 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -65,12 +65,11 @@ class Spark(Hive): this=list_get(args, 0), start=exp.Sub( this=exp.Length(this=list_get(args, 0)), - expression=exp.Add( - this=list_get(args, 1), expression=exp.Literal.number(1) - ), + expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)), ), length=list_get(args, 1), ), + "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, } class Generator(Hive.Generator): @@ -82,11 +81,7 @@ class Spark(Hive): } TRANSFORMS = { - **{ - k: v - for k, v in Hive.Generator.TRANSFORMS.items() - if k not in {exp.ArraySort} - }, + **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort}}, exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), @@ -102,5 +97,5 @@ class Spark(Hive): HiveMap: _map_sql, } - def bitstring_sql(self, expression): - return f"X'{self.sql(expression, 'this')}'" + class Tokenizer(Hive.Tokenizer): + HEX_STRINGS = [("X'", "'")] diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 6cf5022..cfdbe1b 
100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -16,6 +16,7 @@ from sqlglot.tokens import Tokenizer, TokenType class SQLite(Dialect): class Tokenizer(Tokenizer): IDENTIFIERS = ['"', ("[", "]"), "`"] + HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")] KEYWORDS = { **Tokenizer.KEYWORDS, diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py index 805106c..9a6f7fe 100644 --- a/sqlglot/dialects/trino.py +++ b/sqlglot/dialects/trino.py @@ -8,3 +8,6 @@ class Trino(Presto): **Presto.Generator.TRANSFORMS, exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", } + + class Tokenizer(Presto.Tokenizer): + HEX_STRINGS = [("X'", "'")] diff --git a/sqlglot/diff.py b/sqlglot/diff.py index 8eeb4e9..0567c12 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -115,13 +115,8 @@ class ChangeDistiller: for kept_source_node_id, kept_target_node_id in matching_set: source_node = self._source_index[kept_source_node_id] target_node = self._target_index[kept_target_node_id] - if ( - not isinstance(source_node, LEAF_EXPRESSION_TYPES) - or source_node == target_node - ): - edit_script.extend( - self._generate_move_edits(source_node, target_node, matching_set) - ) + if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node: + edit_script.extend(self._generate_move_edits(source_node, target_node, matching_set)) edit_script.append(Keep(source_node, target_node)) else: edit_script.append(Update(source_node, target_node)) @@ -132,9 +127,7 @@ class ChangeDistiller: source_args = [id(e) for e in _expression_only_args(source)] target_args = [id(e) for e in _expression_only_args(target)] - args_lcs = set( - _lcs(source_args, target_args, lambda l, r: (l, r) in matching_set) - ) + args_lcs = set(_lcs(source_args, target_args, lambda l, r: (l, r) in matching_set)) move_edits = [] for a in source_args: @@ -148,14 +141,10 @@ class ChangeDistiller: matching_set = leaves_matching_set.copy() ordered_unmatched_source_nodes = { - id(n[0]): None - for n in self._source.bfs() - if id(n[0]) in self._unmatched_source_nodes + id(n[0]): None for n in self._source.bfs() if id(n[0]) in self._unmatched_source_nodes } ordered_unmatched_target_nodes = { - id(n[0]): None - for n in self._target.bfs() - if id(n[0]) in self._unmatched_target_nodes + id(n[0]): None for n in self._target.bfs() if id(n[0]) in self._unmatched_target_nodes } for source_node_id in ordered_unmatched_source_nodes: @@ -169,18 +158,13 @@ class ChangeDistiller: max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids)) if max_leaves_num: common_leaves_num = sum( - 1 if s in source_leaf_ids and t in target_leaf_ids else 0 - for s, t in leaves_matching_set + 1 if s in source_leaf_ids and t in target_leaf_ids else 0 for s, t in leaves_matching_set ) leaf_similarity_score = common_leaves_num / max_leaves_num else: leaf_similarity_score = 0.0 - adjusted_t = ( - self.t - if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 - else 0.4 - ) + adjusted_t = self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4 if leaf_similarity_score >= 0.8 or ( leaf_similarity_score >= adjusted_t @@ -217,10 +201,7 @@ class ChangeDistiller: matching_set = set() while candidate_matchings: _, _, source_leaf, target_leaf = heappop(candidate_matchings) - if ( - id(source_leaf) in self._unmatched_source_nodes - and id(target_leaf) in self._unmatched_target_nodes - ): + if id(source_leaf) in self._unmatched_source_nodes and id(target_leaf) in 
self._unmatched_target_nodes: matching_set.add((id(source_leaf), id(target_leaf))) self._unmatched_source_nodes.remove(id(source_leaf)) self._unmatched_target_nodes.remove(id(target_leaf)) diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py index a437431..bca9f3e 100644 --- a/sqlglot/executor/__init__.py +++ b/sqlglot/executor/__init__.py @@ -3,11 +3,17 @@ import time from sqlglot import parse_one from sqlglot.executor.python import PythonExecutor -from sqlglot.optimizer import optimize +from sqlglot.optimizer import RULES, optimize +from sqlglot.optimizer.merge_derived_tables import merge_derived_tables from sqlglot.planner import Plan logger = logging.getLogger("sqlglot") +OPTIMIZER_RULES = list(RULES) + +# The executor needs isolated table selects +OPTIMIZER_RULES.remove(merge_derived_tables) + def execute(sql, schema, read=None): """ @@ -28,7 +34,7 @@ def execute(sql, schema, read=None): """ expression = parse_one(sql, read=read) now = time.time() - expression = optimize(expression, schema) + expression = optimize(expression, schema, rules=OPTIMIZER_RULES) logger.debug("Optimization finished: %f", time.time() - now) logger.debug("Optimized SQL: %s", expression.sql(pretty=True)) plan = Plan(expression) diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py index 457bea7..d265a2c 100644 --- a/sqlglot/executor/context.py +++ b/sqlglot/executor/context.py @@ -19,9 +19,7 @@ class Context: env (Optional[dict]): dictionary of functions within the execution context """ self.tables = tables - self.range_readers = { - name: table.range_reader for name, table in self.tables.items() - } + self.range_readers = {name: table.range_reader for name, table in self.tables.items()} self.row_readers = {name: table.reader for name, table in tables.items()} self.env = {**(env or {}), "scope": self.row_readers} diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index 388a419..610aa4b 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -26,11 +26,7 @@ class PythonExecutor: while queue: node = queue.pop() context = self.context( - { - name: table - for dep in node.dependencies - for name, table in contexts[dep].tables.items() - } + {name: table for dep in node.dependencies for name, table in contexts[dep].tables.items()} ) running.add(node) @@ -151,9 +147,7 @@ class PythonExecutor: return self.context({name: table for name in ctx.tables}) for name, join in step.joins.items(): - join_context = self.context( - {**join_context.tables, name: context.tables[name]} - ) + join_context = self.context({**join_context.tables, name: context.tables[name]}) if join.get("source_key"): table = self.hash_join(join, source, name, join_context) @@ -247,9 +241,7 @@ class PythonExecutor: if step.operands: source_table = context.tables[source] - operand_table = Table( - source_table.columns + self.table(step.operands).columns - ) + operand_table = Table(source_table.columns + self.table(step.operands).columns) for reader, ctx in context: operand_table.append(reader.row + ctx.eval_tuple(operands)) diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py index 6df49f7..80674cb 100644 --- a/sqlglot/executor/table.py +++ b/sqlglot/executor/table.py @@ -37,10 +37,7 @@ class Table: break lines.append( - " ".join( - str(row[column]).rjust(widths[column])[0 : widths[column]] - for column in self.columns - ) + " ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in self.columns) ) return "\n".join(lines) diff --git 
a/sqlglot/expressions.py b/sqlglot/expressions.py index 7acc63d..b983bf9 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -47,10 +47,7 @@ class Expression(metaclass=_Expression): return hash( ( self.key, - tuple( - (k, tuple(v) if isinstance(v, list) else v) - for k, v in _norm_args(self).items() - ), + tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()), ) ) @@ -116,9 +113,22 @@ class Expression(metaclass=_Expression): item.parent = parent return new + def append(self, arg_key, value): + """ + Appends value to arg_key if it's a list or sets it as a new list. + + Args: + arg_key (str): name of the list expression arg + value (Any): value to append to the list + """ + if not isinstance(self.args.get(arg_key), list): + self.args[arg_key] = [] + self.args[arg_key].append(value) + self._set_parent(arg_key, value) + def set(self, arg_key, value): """ - Sets `arg` to `value`. + Sets `arg_key` to `value`. Args: arg_key (str): name of the expression arg @@ -267,6 +277,14 @@ class Expression(metaclass=_Expression): expression = expression.this return expression + def unalias(self): + """ + Returns the inner expression if this is an Alias. + """ + if isinstance(self, Alias): + return self.this + return self + def unnest_operands(self): """ Returns unnested operands as a tuple. @@ -279,9 +297,7 @@ class Expression(metaclass=_Expression): A AND B AND C -> [A, B, C] """ - for node, _, _ in self.dfs( - prune=lambda n, p, *_: p and not isinstance(n, self.__class__) - ): + for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)): if not isinstance(node, self.__class__): yield node.unnest() if unnest else node @@ -314,9 +330,7 @@ class Expression(metaclass=_Expression): args = { k: ", ".join( - v.to_s(hide_missing=hide_missing, level=level + 1) - if hasattr(v, "to_s") - else str(v) + v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v) for v in ensure_list(vs) if v is not None ) @@ -354,9 +368,7 @@ class Expression(metaclass=_Expression): new_node.parent = node.parent return new_node - replace_children( - new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs) - ) + replace_children(new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs)) return new_node def replace(self, expression): @@ -546,6 +558,10 @@ class BitString(Condition): pass +class HexString(Condition): + pass + + class Column(Condition): arg_types = {"this": True, "table": False} @@ -566,35 +582,44 @@ class ColumnConstraint(Expression): arg_types = {"this": False, "kind": True} -class AutoIncrementColumnConstraint(Expression): +class ColumnConstraintKind(Expression): pass -class CheckColumnConstraint(Expression): +class AutoIncrementColumnConstraint(ColumnConstraintKind): pass -class CollateColumnConstraint(Expression): +class CheckColumnConstraint(ColumnConstraintKind): pass -class CommentColumnConstraint(Expression): +class CollateColumnConstraint(ColumnConstraintKind): pass -class DefaultColumnConstraint(Expression): +class CommentColumnConstraint(ColumnConstraintKind): pass -class NotNullColumnConstraint(Expression): +class DefaultColumnConstraint(ColumnConstraintKind): pass -class PrimaryKeyColumnConstraint(Expression): +class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): + # this: True -> ALWAYS, this: False -> BY DEFAULT + arg_types = {"this": True, "expression": False} + + +class NotNullColumnConstraint(ColumnConstraintKind): pass -class 
UniqueColumnConstraint(Expression): +class PrimaryKeyColumnConstraint(ColumnConstraintKind): + pass + + +class UniqueColumnConstraint(ColumnConstraintKind): pass @@ -651,9 +676,7 @@ class Identifier(Expression): return bool(self.args.get("quoted")) def __eq__(self, other): - return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg( - other.this - ) + return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this) def __hash__(self): return hash((self.key, self.this.lower())) @@ -709,9 +732,7 @@ class Literal(Condition): def __eq__(self, other): return ( - isinstance(other, Literal) - and self.this == other.this - and self.args["is_string"] == other.args["is_string"] + isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"] ) def __hash__(self): @@ -733,6 +754,7 @@ class Join(Expression): "side": False, "kind": False, "using": False, + "natural": False, } @property @@ -743,6 +765,10 @@ class Join(Expression): def side(self): return self.text("side").upper() + @property + def alias_or_name(self): + return self.this.alias_or_name + def on(self, *expressions, append=True, dialect=None, copy=True, **opts): """ Append to or set the ON expressions. @@ -873,10 +899,6 @@ class Reference(Expression): arg_types = {"this": True, "expressions": True} -class Table(Expression): - arg_types = {"this": True, "db": False, "catalog": False} - - class Tuple(Expression): arg_types = {"expressions": False} @@ -986,6 +1008,16 @@ QUERY_MODIFIERS = { } +class Table(Expression): + arg_types = { + "this": True, + "db": False, + "catalog": False, + "laterals": False, + "joins": False, + } + + class Union(Subqueryable, Expression): arg_types = { "with": False, @@ -1396,7 +1428,9 @@ class Select(Subqueryable, Expression): join.this.replace(join.this.subquery()) if join_type: - side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) + natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) + if natural: + join.set("natural", True) if side: join.set("side", side.text) if kind: @@ -1529,10 +1563,7 @@ class Select(Subqueryable, Expression): properties_expression = None if properties: properties_str = " ".join( - [ - f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" - for k, v in properties.items() - ] + [f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in properties.items()] ) properties_expression = maybe_parse( properties_str, @@ -1654,6 +1685,7 @@ class DataType(Expression): DECIMAL = auto() BOOLEAN = auto() JSON = auto() + INTERVAL = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() DATE = auto() @@ -1662,15 +1694,19 @@ class DataType(Expression): MAP = auto() UUID = auto() GEOGRAPHY = auto() + GEOMETRY = auto() STRUCT = auto() NULLABLE = auto() + HLLSKETCH = auto() + SUPER = auto() + SERIAL = auto() + SMALLSERIAL = auto() + BIGSERIAL = auto() @classmethod def build(cls, dtype, **kwargs): return DataType( - this=dtype - if isinstance(dtype, DataType.Type) - else DataType.Type[dtype.upper()], + this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()], **kwargs, ) @@ -1798,6 +1834,14 @@ class Like(Binary, Predicate): pass +class SimilarTo(Binary, Predicate): + pass + + +class Distance(Binary): + pass + + class LT(Binary, Predicate): pass @@ -1899,6 +1943,10 @@ class IgnoreNulls(Expression): pass +class RespectNulls(Expression): + pass + + # Functions class Func(Condition): """ @@ -1924,9 +1972,7 @@ class Func(Condition): all_arg_keys = 
list(cls.arg_types) # If this function supports variable length argument treat the last argument as such. - non_var_len_arg_keys = ( - all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys - ) + non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys args_dict = {} arg_idx = 0 @@ -1944,9 +1990,7 @@ class Func(Condition): @classmethod def sql_names(cls): if cls is Func: - raise NotImplementedError( - "SQL name is only supported by concrete function implementations" - ) + raise NotImplementedError("SQL name is only supported by concrete function implementations") if not hasattr(cls, "_sql_names"): cls._sql_names = [camel_to_snake_case(cls.__name__)] return cls._sql_names @@ -2178,6 +2222,10 @@ class Greatest(Func): is_var_len_args = True +class GroupConcat(Func): + arg_types = {"this": True, "separator": False} + + class If(Func): arg_types = {"this": True, "true": True, "false": False} @@ -2274,6 +2322,10 @@ class Quantile(AggFunc): arg_types = {"this": True, "quantile": True} +class ApproxQuantile(Quantile): + pass + + class Reduce(Func): arg_types = {"this": True, "initial": True, "merge": True, "finish": True} @@ -2306,8 +2358,10 @@ class Split(Func): arg_types = {"this": True, "expression": True} +# Start may be omitted in the case of postgres +# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 class Substring(Func): - arg_types = {"this": True, "start": True, "length": False} + arg_types = {"this": True, "start": False, "length": False} class StrPosition(Func): @@ -2379,6 +2433,15 @@ class TimeStrToUnix(Func): pass +class Trim(Func): + arg_types = { + "this": True, + "position": False, + "expression": False, + "collation": False, + } + + class TsOrDsAdd(Func, TimeUnit): arg_types = {"this": True, "expression": True, "unit": False} @@ -2455,9 +2518,7 @@ def _all_functions(): obj for _, obj in inspect.getmembers( sys.modules[__name__], - lambda obj: inspect.isclass(obj) - and issubclass(obj, Func) - and obj not in (AggFunc, Anonymous, Func), + lambda obj: inspect.isclass(obj) and issubclass(obj, Func) and obj not in (AggFunc, Anonymous, Func), ) ] @@ -2633,9 +2694,7 @@ def _apply_conjunction_builder( def _combine(expressions, operator, dialect=None, **opts): - expressions = [ - condition(expression, dialect=dialect, **opts) for expression in expressions - ] + expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions] this = expressions[0] if expressions[1:]: this = _wrap_operator(this) @@ -2809,9 +2868,7 @@ def to_identifier(alias, quoted=None): quoted = not re.match(SAFE_IDENTIFIER_RE, alias) identifier = Identifier(this=alias, quoted=quoted) else: - raise ValueError( - f"Alias needs to be a string or an Identifier, got: {alias.__class__}" - ) + raise ValueError(f"Alias needs to be a string or an Identifier, got: {alias.__class__}") return identifier diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 793cff0..a445178 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -41,6 +41,8 @@ class Generator: max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError. This is only relevant if unsupported_level is ErrorLevel.RAISE. 
Default: 3 + leading_comma (bool): whether the comma is leading or trailing in select statements + Default: False """ TRANSFORMS = { @@ -108,6 +110,7 @@ "_indent", "_replace_backslash", "_escaped_quote_end", + "_leading_comma", ) def __init__( @@ -131,6 +134,7 @@ unsupported_level=ErrorLevel.WARN, null_ordering=None, max_unsupported=3, + leading_comma=False, ): import sqlglot @@ -157,6 +161,7 @@ self._indent = indent self._replace_backslash = self.escape == "\\" self._escaped_quote_end = self.escape + self.quote_end + self._leading_comma = leading_comma def generate(self, expression): """ @@ -178,9 +183,7 @@ for msg in self.unsupported_messages: logger.warning(msg) elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages: - raise UnsupportedError( - concat_errors(self.unsupported_messages, self.max_unsupported) - ) + raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported)) return sql @@ -197,9 +200,7 @@ def wrap(self, expression): this_sql = self.indent( - self.sql(expression) - if isinstance(expression, (exp.Select, exp.Union)) - else self.sql(expression, "this"), + self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"), level=1, pad=0, ) @@ -251,9 +252,7 @@ return transform if not isinstance(expression, exp.Expression): - raise ValueError( - f"Expected an Expression. Received {type(expression)}: {expression}" - ) + raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}") exp_handler_name = f"{expression.key}_sql" if hasattr(self, exp_handler_name): @@ -276,11 +275,7 @@ lazy = " LAZY" if expression.args.get("lazy") else "" table = self.sql(expression, "this") options = expression.args.get("options") - options = ( - f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})" - if options - else "" - ) + options = f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})" if options else "" sql = self.sql(expression, "expression") sql = f" AS{self.sep()}{sql}" if sql else "" sql = f"CACHE{lazy} TABLE {table}{options}{sql}" @@ -306,9 +301,7 @@ def columndef_sql(self, expression): column = self.sql(expression, "this") kind = self.sql(expression, "kind") - constraints = self.expressions( - expression, key="constraints", sep=" ", flat=True - ) + constraints = self.expressions(expression, key="constraints", sep=" ", flat=True) if not constraints: return f"{column} {kind}" @@ -338,6 +331,9 @@ default = self.sql(expression, "this") return f"DEFAULT {default}" + def generatedasidentitycolumnconstraint_sql(self, expression): + return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY" + def notnullcolumnconstraint_sql(self, _): return "NOT NULL" @@ -384,7 +380,10 @@ return f"{alias}{columns}" def bitstring_sql(self, expression): - return f"b'{self.sql(expression, 'this')}'" + return self.sql(expression, "this") + + def hexstring_sql(self, expression): + return self.sql(expression, "this") def datatype_sql(self, expression): type_value = expression.this @@ -452,10 +451,7 @@ def partition_sql(self, expression): keys = csv( - *[ - f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] - for k, v in expression.args.get("this") - ] + *[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")] ) return
f"PARTITION({keys})" @@ -470,9 +466,9 @@ class Generator: elif p_class in self.WITH_PROPERTIES: with_properties.append(p) - return self.root_properties( - exp.Properties(expressions=root_properties) - ) + self.with_properties(exp.Properties(expressions=with_properties)) + return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties( + exp.Properties(expressions=with_properties) + ) def root_properties(self, properties): if properties.expressions: @@ -508,11 +504,7 @@ class Generator: kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO" this = self.sql(expression, "this") exists = " IF EXISTS " if expression.args.get("exists") else " " - partition_sql = ( - self.sql(expression, "partition") - if expression.args.get("partition") - else "" - ) + partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else "" expression_sql = self.sql(expression, "expression") sep = self.sep() if partition_sql else "" sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}" @@ -531,7 +523,7 @@ class Generator: return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" def table_sql(self, expression): - return ".".join( + table = ".".join( part for part in [ self.sql(expression, "catalog"), @@ -541,6 +533,10 @@ class Generator: if part ) + laterals = self.expressions(expression, key="laterals", sep="") + joins = self.expressions(expression, key="joins", sep="") + return f"{table}{laterals}{joins}" + def tablesample_sql(self, expression): if self.alias_post_tablesample and isinstance(expression.this, exp.Alias): this = self.sql(expression.this, "this") @@ -586,11 +582,7 @@ class Generator: def group_sql(self, expression): group_by = self.op_expressions("GROUP BY", expression) grouping_sets = self.expressions(expression, key="grouping_sets", indent=False) - grouping_sets = ( - f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" - if grouping_sets - else "" - ) + grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else "" cube = self.expressions(expression, key="cube", indent=False) cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else "" rollup = self.expressions(expression, key="rollup", indent=False) @@ -603,7 +595,16 @@ class Generator: def join_sql(self, expression): op_sql = self.seg( - " ".join(op for op in (expression.side, expression.kind, "JOIN") if op) + " ".join( + op + for op in ( + "NATURAL" if expression.args.get("natural") else None, + expression.side, + expression.kind, + "JOIN", + ) + if op + ) ) on_sql = self.sql(expression, "on") using = expression.args.get("using") @@ -630,9 +631,9 @@ class Generator: def lateral_sql(self, expression): this = self.sql(expression, "this") - op_sql = self.seg( - f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}" - ) + if isinstance(expression.this, exp.Subquery): + return f"LATERAL{self.sep()}{this}" + op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}") alias = expression.args["alias"] table = alias.name table = f" {table}" if table else table @@ -688,21 +689,13 @@ class Generator: sort_order = " DESC" if desc else "" nulls_sort_change = "" - if nulls_first and ( - (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last - ): + if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last): nulls_sort_change = " NULLS FIRST" - elif ( - nulls_last - and ((asc and nulls_are_small) or (desc and 
nulls_are_large)) - and not nulls_are_last - ): + elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last: nulls_sort_change = " NULLS LAST" if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED: - self.unsupported( - "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect" - ) + self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect") nulls_sort_change = "" return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}" @@ -798,14 +791,20 @@ class Generator: def window_sql(self, expression): this = self.sql(expression, "this") + partition = self.expressions(expression, key="partition_by", flat=True) partition = f"PARTITION BY {partition}" if partition else "" + order = expression.args.get("order") order_sql = self.order_sql(order, flat=True) if order else "" + partition_sql = partition + " " if partition and order else partition + spec = expression.args.get("spec") spec_sql = " " + self.window_spec_sql(spec) if spec else "" + alias = self.sql(expression, "alias") + if expression.arg_key == "window": this = this = f"{self.seg('WINDOW')} {this} AS" else: @@ -818,13 +817,8 @@ class Generator: def window_spec_sql(self, expression): kind = self.sql(expression, "kind") - start = csv( - self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" " - ) - end = ( - csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") - or "CURRENT ROW" - ) + start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ") + end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW" return f"{kind} BETWEEN {start} AND {end}" def withingroup_sql(self, expression): @@ -879,6 +873,17 @@ class Generator: expression_sql = self.sql(expression, "expression") return f"EXTRACT({this} FROM {expression_sql})" + def trim_sql(self, expression): + target = self.sql(expression, "this") + trim_type = self.sql(expression, "position") + + if trim_type == "LEADING": + return f"LTRIM({target})" + elif trim_type == "TRAILING": + return f"RTRIM({target})" + else: + return f"TRIM({target})" + def check_sql(self, expression): this = self.sql(expression, key="this") return f"CHECK ({this})" @@ -898,9 +903,7 @@ class Generator: return f"UNIQUE ({columns})" def if_sql(self, expression): - return self.case_sql( - exp.Case(ifs=[expression], default=expression.args.get("false")) - ) + return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false"))) def in_sql(self, expression): query = expression.args.get("query") @@ -917,7 +920,9 @@ class Generator: return f"(SELECT {self.sql(unnest)})" def interval_sql(self, expression): - return f"INTERVAL {self.sql(expression, 'this')} {self.sql(expression, 'unit')}" + unit = self.sql(expression, "unit") + unit = f" {unit}" if unit else "" + return f"INTERVAL {self.sql(expression, 'this')}{unit}" def reference_sql(self, expression): this = self.sql(expression, "this") @@ -925,9 +930,7 @@ class Generator: return f"REFERENCES {this}({expressions})" def anonymous_sql(self, expression): - args = self.indent( - self.expressions(expression, flat=True), skip_first=True, skip_last=True - ) + args = self.indent(self.expressions(expression, flat=True), skip_first=True, skip_last=True) return f"{self.normalize_func(self.sql(expression, 'this'))}({args})" def paren_sql(self, expression): @@ -1006,6 +1009,9 @@ class Generator: def ignorenulls_sql(self, expression): return 
f"{self.sql(expression, 'this')} IGNORE NULLS" + def respectnulls_sql(self, expression): + return f"{self.sql(expression, 'this')} RESPECT NULLS" + def intdiv_sql(self, expression): return self.sql( exp.Cast( @@ -1023,6 +1029,9 @@ class Generator: def div_sql(self, expression): return self.binary(expression, "/") + def distance_sql(self, expression): + return self.binary(expression, "<->") + def dot_sql(self, expression): return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}" @@ -1047,6 +1056,9 @@ class Generator: def like_sql(self, expression): return self.binary(expression, "LIKE") + def similarto_sql(self, expression): + return self.binary(expression, "SIMILAR TO") + def lt_sql(self, expression): return self.binary(expression, "<") @@ -1069,14 +1081,10 @@ class Generator: return self.binary(expression, "-") def trycast_sql(self, expression): - return ( - f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" - ) + return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" def binary(self, expression, op): - return ( - f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}" - ) + return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}" def function_fallback_sql(self, expression): args = [] @@ -1089,9 +1097,7 @@ class Generator: return f"{self.normalize_func(expression.sql_name())}({args_str})" def format_time(self, expression): - return format_time( - self.sql(expression, "format"), self.time_mapping, self.time_trie - ) + return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie) def expressions(self, expression, key=None, flat=False, indent=True, sep=", "): expressions = expression.args.get(key or "expressions") @@ -1102,7 +1108,14 @@ class Generator: if flat: return sep.join(self.sql(e) for e in expressions) - expressions = self.sep(sep).join(self.sql(e) for e in expressions) + sql = (self.sql(e) for e in expressions) + # the only time leading_comma changes the output is if pretty print is enabled + if self._leading_comma and self.pretty: + pad = " " * self.pad + expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql)) + else: + expressions = self.sep(sep).join(sql) + if indent: return self.indent(expressions, skip_first=False) return expressions @@ -1116,9 +1129,7 @@ class Generator: def set_operation(self, expression, op): this = self.sql(expression, "this") op = self.seg(op) - return self.query_modifiers( - expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}" - ) + return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}") def token_sql(self, token_type): return self.TOKEN_MAPPING.get(token_type, token_type.name) diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py index a4c4cc2..d1146ca 100644 --- a/sqlglot/optimizer/__init__.py +++ b/sqlglot/optimizer/__init__.py @@ -1,2 +1,2 @@ -from sqlglot.optimizer.optimizer import optimize +from sqlglot.optimizer.optimizer import RULES, optimize from sqlglot.optimizer.schema import Schema diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py index c2e021e..e060739 100644 --- a/sqlglot/optimizer/isolate_table_selects.py +++ b/sqlglot/optimizer/isolate_table_selects.py @@ -13,9 +13,7 @@ def isolate_table_selects(expression): continue if not isinstance(source.parent, exp.Alias): - raise OptimizeError( - "Tables require an alias. 
Run qualify_tables optimization." - ) + raise OptimizeError("Tables require an alias. Run qualify_tables optimization.") parent = source.parent diff --git a/sqlglot/optimizer/merge_derived_tables.py b/sqlglot/optimizer/merge_derived_tables.py new file mode 100644 index 0000000..8b161fb --- /dev/null +++ b/sqlglot/optimizer/merge_derived_tables.py @@ -0,0 +1,232 @@ +from collections import defaultdict + +from sqlglot import expressions as exp +from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.simplify import simplify + + +def merge_derived_tables(expression): + """ + Rewrite sqlglot AST to merge derived tables into the outer query. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)") + >>> merge_derived_tables(expression).sql() + 'SELECT x.a FROM x' + + Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html + + Args: + expression (sqlglot.Expression): expression to optimize + Returns: + sqlglot.Expression: optimized expression + """ + for outer_scope in traverse_scope(expression): + for subquery in outer_scope.derived_tables: + inner_select = subquery.unnest() + if ( + isinstance(outer_scope.expression, exp.Select) + and isinstance(inner_select, exp.Select) + and _mergeable(inner_select) + ): + alias = subquery.alias_or_name + from_or_join = subquery.find_ancestor(exp.From, exp.Join) + inner_scope = outer_scope.sources[alias] + + _rename_inner_sources(outer_scope, inner_scope, alias) + _merge_from(outer_scope, inner_scope, subquery) + _merge_joins(outer_scope, inner_scope, from_or_join) + _merge_expressions(outer_scope, inner_scope, alias) + _merge_where(outer_scope, inner_scope, from_or_join) + _merge_order(outer_scope, inner_scope) + return expression + + +# If a derived table has these Select args, it can't be merged +UNMERGABLE_ARGS = set(exp.Select.arg_types) - { + "expressions", + "from", + "joins", + "where", + "order", +} + + +def _mergeable(inner_select): + """ + Return True if `inner_select` can be merged into outer query. + + Args: + inner_select (exp.Select) + Returns: + bool: True if can be merged + """ + return ( + isinstance(inner_select, exp.Select) + and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) + and inner_select.args.get("from") + and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) + ) + + +def _rename_inner_sources(outer_scope, inner_scope, alias): + """ + Renames any sources in the inner query that conflict with names in the outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + alias (str) + """ + taken = set(outer_scope.selected_sources) + conflicts = taken.intersection(set(inner_scope.selected_sources)) + conflicts = conflicts - {alias} + + for conflict in conflicts: + new_name = _find_new_name(taken, conflict) + + source, _ = inner_scope.selected_sources[conflict] + new_alias = exp.to_identifier(new_name) + + if isinstance(source, exp.Subquery): + source.set("alias", exp.TableAlias(this=new_alias)) + elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias): + source.parent.set("alias", new_alias) + elif isinstance(source, exp.Table): + source.replace(exp.alias_(source.copy(), new_alias)) + + for column in inner_scope.source_columns(conflict): + column.set("table", exp.to_identifier(new_name)) + + inner_scope.rename_source(conflict, new_name) + + +def _find_new_name(taken, base): + """ + Searches for a new source name. 
+ + Args: + taken (set[str]): set of taken names + base (str): base name to alter + """ + i = 2 + new = f"{base}_{i}" + while new in taken: + i += 1 + new = f"{base}_{i}" + return new + + +def _merge_from(outer_scope, inner_scope, subquery): + """ + Merge FROM clause of inner query into outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + subquery (exp.Subquery) + """ + new_subquery = inner_scope.expression.args.get("from").expressions[0] + subquery.replace(new_subquery) + outer_scope.remove_source(subquery.alias_or_name) + outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]) + + +def _merge_joins(outer_scope, inner_scope, from_or_join): + """ + Merge JOIN clauses of inner query into outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + from_or_join (exp.From|exp.Join) + """ + + new_joins = [] + comma_joins = inner_scope.expression.args.get("from").expressions[1:] + for subquery in comma_joins: + new_joins.append(exp.Join(this=subquery, kind="CROSS")) + outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name]) + + joins = inner_scope.expression.args.get("joins") or [] + for join in joins: + new_joins.append(join) + outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name]) + + if new_joins: + outer_joins = outer_scope.expression.args.get("joins", []) + + # Maintain the join order + if isinstance(from_or_join, exp.From): + position = 0 + else: + position = outer_joins.index(from_or_join) + 1 + outer_joins[position:position] = new_joins + + outer_scope.expression.set("joins", outer_joins) + + +def _merge_expressions(outer_scope, inner_scope, alias): + """ + Merge projections of inner query into outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + alias (str) + """ + # Collect all columns that reference the alias of the inner query + outer_columns = defaultdict(list) + for column in outer_scope.columns: + if column.table == alias: + outer_columns[column.name].append(column) + + # Replace columns with the projection expression in the inner query + for expression in inner_scope.expression.expressions: + projection_name = expression.alias_or_name + if not projection_name: + continue + columns_to_replace = outer_columns.get(projection_name, []) + for column in columns_to_replace: + column.replace(expression.unalias()) + + +def _merge_where(outer_scope, inner_scope, from_or_join): + """ + Merge WHERE clause of inner query into outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + from_or_join (exp.From|exp.Join) + """ + where = inner_scope.expression.args.get("where") + if not where or not where.this: + return + + if isinstance(from_or_join, exp.Join) and from_or_join.side: + # Merge predicates from an outer join to the ON clause + from_or_join.on(where.this, copy=False) + from_or_join.set("on", simplify(from_or_join.args.get("on"))) + else: + outer_scope.expression.where(where.this, copy=False) + outer_scope.expression.set("where", simplify(outer_scope.expression.args.get("where"))) + + +def _merge_order(outer_scope, inner_scope): + """ + Merge ORDER clause of inner query into outer query.
+ + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + """ + if ( + any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]) + or len(outer_scope.selected_sources) != 1 + or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions) + ): + return + + outer_scope.expression.set("order", inner_scope.expression.args.get("order")) diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index 2c9f89c..ab30d7a 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -22,18 +22,14 @@ def normalize(expression, dnf=False, max_distance=128): """ expression = simplify(expression) - expression = while_changing( - expression, lambda e: distributive_law(e, dnf, max_distance) - ) + expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance)) return simplify(expression) def normalized(expression, dnf=False): ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) - return not any( - connector.find_ancestor(ancestor) for connector in expression.find_all(root) - ) + return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root)) def normalization_distance(expression, dnf=False): @@ -54,9 +50,7 @@ def normalization_distance(expression, dnf=False): Returns: int: difference """ - return sum(_predicate_lengths(expression, dnf)) - ( - len(list(expression.find_all(exp.Connector))) + 1 - ) + return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1) def _predicate_lengths(expression, dnf): @@ -73,11 +67,7 @@ def _predicate_lengths(expression, dnf): left, right = expression.args.values() if isinstance(expression, exp.And if dnf else exp.Or): - x = [ - a + b - for a in _predicate_lengths(left, dnf) - for b in _predicate_lengths(right, dnf) - ] + x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)] return x return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) @@ -102,9 +92,7 @@ def distributive_law(expression, dnf, max_distance): to_func = exp.and_ if to_exp == exp.And else exp.or_ if isinstance(a, to_exp) and isinstance(b, to_exp): - if len(tuple(a.find_all(exp.Connector))) > len( - tuple(b.find_all(exp.Connector)) - ): + if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): return _distribute(a, b, from_func, to_func) return _distribute(b, a, from_func, to_func) if isinstance(a, to_exp): diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 40e4ab1..0c74e36 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -68,8 +68,4 @@ def normalize(expression): def other_table_names(join, exclude): - return [ - name - for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) - if name != exclude - ] + return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude] diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index c03fe3c..c8c2403 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -1,6 +1,7 @@ from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects from sqlglot.optimizer.isolate_table_selects import isolate_table_selects +from sqlglot.optimizer.merge_derived_tables import 
merge_derived_tables from sqlglot.optimizer.normalize import normalize from sqlglot.optimizer.optimize_joins import optimize_joins from sqlglot.optimizer.pushdown_predicates import pushdown_predicates @@ -10,8 +11,23 @@ from sqlglot.optimizer.qualify_tables import qualify_tables from sqlglot.optimizer.quote_identities import quote_identities from sqlglot.optimizer.unnest_subqueries import unnest_subqueries +RULES = ( + qualify_tables, + isolate_table_selects, + qualify_columns, + pushdown_projections, + normalize, + unnest_subqueries, + expand_multi_table_selects, + pushdown_predicates, + optimize_joins, + eliminate_subqueries, + merge_derived_tables, + quote_identities, +) -def optimize(expression, schema=None, db=None, catalog=None): + +def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwargs): """ Rewrite a sqlglot AST into an optimized form. @@ -25,19 +41,18 @@ def optimize(expression, schema=None, db=None, catalog=None): 3. {catalog: {db: {table: {col: type}}}} db (str): specify the default database, as might be set by a `USE DATABASE db` statement catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement + rules (list): sequence of optimizer rules to use + **kwargs: If a rule has a keyword argument with the same name in **kwargs, it will be passed in. Returns: sqlglot.Expression: optimized expression """ + possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs} expression = expression.copy() - expression = qualify_tables(expression, db=db, catalog=catalog) - expression = isolate_table_selects(expression) - expression = qualify_columns(expression, schema) - expression = pushdown_projections(expression) - expression = normalize(expression) - expression = unnest_subqueries(expression) - expression = expand_multi_table_selects(expression) - expression = pushdown_predicates(expression) - expression = optimize_joins(expression) - expression = eliminate_subqueries(expression) - expression = quote_identities(expression) + for rule in rules: + + # Find any additional rule parameters, beyond `expression` + rule_params = rule.__code__.co_varnames + rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs} + + expression = rule(expression, **rule_kwargs) return expression diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index e757322..a070d70 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -42,11 +42,7 @@ def pushdown(condition, sources): condition = condition.replace(simplify(condition)) cnf_like = normalized(condition) or not normalized(condition, dnf=True) - predicates = list( - condition.flatten() - if isinstance(condition, exp.And if cnf_like else exp.Or) - else [condition] - ) + predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition]) if cnf_like: pushdown_cnf(predicates, sources) @@ -105,17 +101,11 @@ def pushdown_dnf(predicates, scope): for column in predicate.find_all(exp.Column): if column.table == table: condition = column.find_ancestor(exp.Condition) - predicate_condition = ( - exp.and_(predicate_condition, condition) - if predicate_condition - else condition - ) + predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition if predicate_condition: conditions[table] = ( - exp.or_(conditions[table], predicate_condition) - if table in conditions - else predicate_condition
+ exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition ) for name, node in nodes.items(): @@ -133,9 +123,7 @@ def pushdown_dnf(predicates, scope): def nodes_for_predicate(predicate, sources): nodes = {} tables = exp.column_table_names(predicate) - where_condition = isinstance( - predicate.find_ancestor(exp.Join, exp.Where), exp.Where - ) + where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where) for table in tables: node, source = sources.get(table) or (None, None) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 394f49e..0bb947a 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -226,9 +226,7 @@ def _expand_stars(scope, resolver): tables = list(scope.selected_sources) _add_except_columns(expression, tables, except_columns) _add_replace_columns(expression, tables, replace_columns) - elif isinstance(expression, exp.Column) and isinstance( - expression.this, exp.Star - ): + elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star): tables = [expression.table] _add_except_columns(expression.this, tables, except_columns) _add_replace_columns(expression.this, tables, replace_columns) @@ -245,9 +243,7 @@ def _expand_stars(scope, resolver): if name not in except_columns.get(table_id, set()): alias_ = replace_columns.get(table_id, {}).get(name, name) column = exp.column(name, table) - new_selections.append( - alias(column, alias_) if alias_ != name else column - ) + new_selections.append(alias(column, alias_) if alias_ != name else column) scope.expression.set("expressions", new_selections) @@ -280,9 +276,7 @@ def _qualify_outputs(scope): """Ensure all output columns are aliased""" new_selections = [] - for i, (selection, aliased_column) in enumerate( - itertools.zip_longest(scope.selects, scope.outer_column_list) - ): + for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)): if isinstance(selection, exp.Column): # convoluted setter because a simple selection.replace(alias) would require a copy alias_ = alias(exp.column(""), alias=selection.name) @@ -302,11 +296,7 @@ def _qualify_outputs(scope): def _check_unknown_tables(scope): - if ( - scope.external_columns - and not scope.is_unnest - and not scope.is_correlated_subquery - ): + if scope.external_columns and not scope.is_unnest and not scope.is_correlated_subquery: raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}") @@ -334,20 +324,14 @@ class _Resolver: (str) table name """ if self._unambiguous_columns is None: - self._unambiguous_columns = self._get_unambiguous_columns( - self._get_all_source_columns() - ) + self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns()) return self._unambiguous_columns.get(column_name) @property def all_columns(self): """All available columns of all sources in this scope""" if self._all_columns is None: - self._all_columns = set( - column - for columns in self._get_all_source_columns().values() - for column in columns - ) + self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns) return self._all_columns def get_source_columns(self, name): @@ -369,9 +353,7 @@ class _Resolver: def _get_all_source_columns(self): if self._source_columns is None: - self._source_columns = { - k: self.get_source_columns(k) for k in self.scope.selected_sources - } + self._source_columns = 
{k: self.get_source_columns(k) for k in self.scope.selected_sources} return self._source_columns def _get_unambiguous_columns(self, source_columns): @@ -389,9 +371,7 @@ class _Resolver: source_columns = list(source_columns.items()) first_table, first_columns = source_columns[0] - unambiguous_columns = { - col: first_table for col in self._find_unique_columns(first_columns) - } + unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)} all_columns = set(unambiguous_columns) for table, columns in source_columns[1:]: diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 9f8b9f5..30e93ba 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -27,9 +27,7 @@ def qualify_tables(expression, db=None, catalog=None): for derived_table in scope.ctes + scope.derived_tables: if not derived_table.args.get("alias"): alias_ = f"_q_{next(sequence)}" - derived_table.set( - "alias", exp.TableAlias(this=exp.to_identifier(alias_)) - ) + derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_))) scope.rename_source(None, alias_) for source in scope.sources.values(): diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py index 9968108..1761228 100644 --- a/sqlglot/optimizer/schema.py +++ b/sqlglot/optimizer/schema.py @@ -57,9 +57,7 @@ class MappingSchema(Schema): for forbidden in self.forbidden_args: if table.text(forbidden): - raise ValueError( - f"Schema doesn't support {forbidden}. Received: {table.sql()}" - ) + raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") return list(_nested_get(self.schema, *zip(self.supported_table_args, args))) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index f6f59e8..e816e10 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -104,9 +104,7 @@ class Scope: elif isinstance(node, exp.CTE): self._ctes.append(node) prune = True - elif isinstance(node, exp.Subquery) and isinstance( - parent, (exp.From, exp.Join) - ): + elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): self._derived_tables.append(node) prune = True elif isinstance(node, exp.Subqueryable): @@ -195,20 +193,14 @@ class Scope: self._ensure_collected() columns = self._raw_columns - external_columns = [ - column - for scope in self.subquery_scopes - for column in scope.external_columns - ] + external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns] named_outputs = {e.alias_or_name for e in self.expression.expressions} self._columns = [ c for c in columns + external_columns - if not ( - c.find_ancestor(exp.Qualify, exp.Order) and c.name in named_outputs - ) + if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs) ] return self._columns @@ -229,9 +221,7 @@ class Scope: for table in self.tables: referenced_names.append( ( - table.parent.alias - if isinstance(table.parent, exp.Alias) - else table.name, + table.parent.alias if isinstance(table.parent, exp.Alias) else table.name, table, ) ) @@ -274,9 +264,7 @@ class Scope: sources in the current scope. 
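+ For example, in SELECT * FROM x WHERE x.a = (SELECT y.b FROM y WHERE y.c = x.d), the column x.d is external to the subquery scope, whose only selected source is y.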
""" if self._external_columns is None: - self._external_columns = [ - c for c in self.columns if c.table not in self.selected_sources - ] + self._external_columns = [c for c in self.columns if c.table not in self.selected_sources] return self._external_columns def source_columns(self, source_name): @@ -310,6 +298,16 @@ class Scope: columns = self.sources.pop(old_name or "", []) self.sources[new_name] = columns + def add_source(self, name, source): + """Add a source to this scope""" + self.sources[name] = source + self.clear_cache() + + def remove_source(self, name): + """Remove a source from this scope""" + self.sources.pop(name, None) + self.clear_cache() + def traverse_scope(expression): """ @@ -334,7 +332,7 @@ def traverse_scope(expression): Args: expression (exp.Expression): expression to traverse Returns: - List[Scope]: scope instances + list[Scope]: scope instances """ return list(_traverse_scope(Scope(expression))) @@ -356,9 +354,7 @@ def _traverse_scope(scope): def _traverse_select(scope): yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE) yield from _traverse_subqueries(scope) - yield from _traverse_derived_tables( - scope.derived_tables, scope, ScopeType.DERIVED_TABLE - ) + yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE) _add_table_sources(scope) @@ -367,15 +363,11 @@ def _traverse_union(scope): # The last scope to be yield should be the top most scope left = None - for left in _traverse_scope( - scope.branch(scope.expression.left, scope_type=ScopeType.UNION) - ): + for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)): yield left right = None - for right in _traverse_scope( - scope.branch(scope.expression.right, scope_type=ScopeType.UNION) - ): + for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)): yield right scope.union = (left, right) @@ -387,14 +379,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type): for derived_table in derived_tables: for child_scope in _traverse_scope( scope.branch( - derived_table - if isinstance(derived_table, (exp.Unnest, exp.Lateral)) - else derived_table.this, + derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this, add_sources=sources if scope_type == ScopeType.CTE else None, outer_column_list=derived_table.alias_column_names, - scope_type=ScopeType.UNNEST - if isinstance(derived_table, exp.Unnest) - else scope_type, + scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type, ) ): yield child_scope @@ -430,9 +418,7 @@ def _add_table_sources(scope): def _traverse_subqueries(scope): for subquery in scope.subqueries: top = None - for child_scope in _traverse_scope( - scope.branch(subquery, scope_type=ScopeType.SUBQUERY) - ): + for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)): yield child_scope top = child_scope scope.subquery_scopes.append(top) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 6771153..319e6b6 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -188,9 +188,7 @@ def absorb_and_eliminate(expression): aa.replace(exp.TRUE if kind == exp.And else exp.FALSE) elif is_complement(b, ab): ab.replace(exp.TRUE if kind == exp.And else exp.FALSE) - elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set( - a.flatten() - ): + elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): 
a.replace(exp.FALSE if kind == exp.And else exp.TRUE) elif isinstance(b, kind): # eliminate @@ -227,9 +225,7 @@ def simplify_literals(expression): operands.append(a) if len(operands) < size: - return functools.reduce( - lambda a, b: expression.__class__(this=a, expression=b), operands - ) + return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands) elif isinstance(expression, exp.Neg): this = expression.this if this.is_number: diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 55c81c5..11c6eba 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -89,11 +89,7 @@ def decorrelate(select, parent_select, external_columns, sequence): return if isinstance(predicate, exp.Binary): - key = ( - predicate.right - if any(node is column for node, *_ in predicate.left.walk()) - else predicate.left - ) + key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left else: return @@ -124,9 +120,7 @@ def decorrelate(select, parent_select, external_columns, sequence): # if the value of the subquery is not an agg or a key, we need to collect it into an array # so that it can be grouped if not value.find(exp.AggFunc) and value.this not in group_by: - select.select( - f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False - ) + select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False) # exists queries should not have any selects as it only checks if there are any rows # all selects will be added by the optimizer and only used for join keys @@ -151,16 +145,12 @@ def decorrelate(select, parent_select, external_columns, sequence): else: parent_predicate = _replace(parent_predicate, "TRUE") elif isinstance(parent_predicate, exp.All): - parent_predicate = _replace( - parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})" - ) + parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})") elif isinstance(parent_predicate, exp.Any): if value.this in group_by: parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}") else: - parent_predicate = _replace( - parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})" - ) + parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})") elif isinstance(parent_predicate, exp.In): if value.this in group_by: parent_predicate = _replace(parent_predicate, f"{other} = {alias}") @@ -178,9 +168,7 @@ def decorrelate(select, parent_select, external_columns, sequence): if key in group_by: key.replace(nested) - parent_predicate = _replace( - parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)" - ) + parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)") elif isinstance(predicate, exp.EQ): parent_predicate = _replace( parent_predicate, diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 9396c50..f46bafe 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -78,6 +78,7 @@ class Parser: TokenType.TEXT, TokenType.BINARY, TokenType.JSON, + TokenType.INTERVAL, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, TokenType.DATETIME, @@ -85,6 +86,12 @@ class Parser: TokenType.DECIMAL, TokenType.UUID, TokenType.GEOGRAPHY, + TokenType.GEOMETRY, + TokenType.HLLSKETCH, + TokenType.SUPER, + TokenType.SERIAL, + TokenType.SMALLSERIAL, + TokenType.BIGSERIAL, *NESTED_TYPE_TOKENS, } @@ -100,13 +107,14 @@ class Parser: ID_VAR_TOKENS = { TokenType.VAR, 
TokenType.ALTER, + TokenType.ALWAYS, TokenType.BEGIN, + TokenType.BOTH, TokenType.BUCKET, TokenType.CACHE, TokenType.COLLATE, TokenType.COMMIT, TokenType.CONSTRAINT, - TokenType.CONVERT, TokenType.DEFAULT, TokenType.DELETE, TokenType.ENGINE, @@ -115,14 +123,19 @@ class Parser: TokenType.FALSE, TokenType.FIRST, TokenType.FOLLOWING, + TokenType.FOR, TokenType.FORMAT, TokenType.FUNCTION, + TokenType.GENERATED, + TokenType.IDENTITY, TokenType.IF, TokenType.INDEX, TokenType.ISNULL, TokenType.INTERVAL, TokenType.LAZY, + TokenType.LEADING, TokenType.LOCATION, + TokenType.NATURAL, TokenType.NEXT, TokenType.ONLY, TokenType.OPTIMIZE, @@ -141,6 +154,7 @@ class Parser: TokenType.TABLE_FORMAT, TokenType.TEMPORARY, TokenType.TOP, + TokenType.TRAILING, TokenType.TRUNCATE, TokenType.TRUE, TokenType.UNBOUNDED, @@ -150,18 +164,15 @@ class Parser: *TYPE_TOKENS, } - CASTS = { - TokenType.CAST, - TokenType.TRY_CAST, - } + TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL} + + TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH} FUNC_TOKENS = { - TokenType.CONVERT, TokenType.CURRENT_DATE, TokenType.CURRENT_DATETIME, TokenType.CURRENT_TIMESTAMP, TokenType.CURRENT_TIME, - TokenType.EXTRACT, TokenType.FILTER, TokenType.FIRST, TokenType.FORMAT, @@ -178,7 +189,6 @@ class Parser: TokenType.DATETIME, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, - *CASTS, *NESTED_TYPE_TOKENS, *SUBQUERY_PREDICATES, } @@ -215,6 +225,7 @@ class Parser: FACTOR = { TokenType.DIV: exp.IntDiv, + TokenType.LR_ARROW: exp.Distance, TokenType.SLASH: exp.Div, TokenType.STAR: exp.Mul, } @@ -299,14 +310,13 @@ class Parser: PRIMARY_PARSERS = { TokenType.STRING: lambda _, token: exp.Literal.string(token.text), TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text), - TokenType.STAR: lambda self, _: exp.Star( - **{"except": self._parse_except(), "replace": self._parse_replace()} - ), + TokenType.STAR: lambda self, _: exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}), TokenType.NULL: lambda *_: exp.Null(), TokenType.TRUE: lambda *_: exp.Boolean(this=True), TokenType.FALSE: lambda *_: exp.Boolean(this=False), TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(), TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text), + TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text), TokenType.INTRODUCER: lambda self, token: self.expression( exp.Introducer, this=token.text, @@ -319,13 +329,16 @@ class Parser: TokenType.IN: lambda self, this: self._parse_in(this), TokenType.IS: lambda self, this: self._parse_is(this), TokenType.LIKE: lambda self, this: self._parse_escape( - self.expression(exp.Like, this=this, expression=self._parse_type()) + self.expression(exp.Like, this=this, expression=self._parse_bitwise()) ), TokenType.ILIKE: lambda self, this: self._parse_escape( - self.expression(exp.ILike, this=this, expression=self._parse_type()) + self.expression(exp.ILike, this=this, expression=self._parse_bitwise()) ), TokenType.RLIKE: lambda self, this: self.expression( - exp.RegexpLike, this=this, expression=self._parse_type() + exp.RegexpLike, this=this, expression=self._parse_bitwise() + ), + TokenType.SIMILAR_TO: lambda self, this: self.expression( + exp.SimilarTo, this=this, expression=self._parse_bitwise() ), } @@ -363,28 +376,21 @@ class Parser: } FUNCTION_PARSERS = { - TokenType.CONVERT: lambda self, _: self._parse_convert(), - TokenType.EXTRACT: lambda self, _: self._parse_extract(), - **{ - token_type: lambda self, token_type: self._parse_cast( - self.STRICT_CAST and 
token_type == TokenType.CAST - ) - for token_type in CASTS - }, + "CONVERT": lambda self: self._parse_convert(), + "EXTRACT": lambda self: self._parse_extract(), + "SUBSTRING": lambda self: self._parse_substring(), + "TRIM": lambda self: self._parse_trim(), + "CAST": lambda self: self._parse_cast(self.STRICT_CAST), + "TRY_CAST": lambda self: self._parse_cast(False), } QUERY_MODIFIER_PARSERS = { - "laterals": lambda self: self._parse_laterals(), - "joins": lambda self: self._parse_joins(), "where": lambda self: self._parse_where(), "group": lambda self: self._parse_group(), "having": lambda self: self._parse_having(), "qualify": lambda self: self._parse_qualify(), - "window": lambda self: self._match(TokenType.WINDOW) - and self._parse_window(self._parse_id_var(), alias=True), - "distribute": lambda self: self._parse_sort( - TokenType.DISTRIBUTE_BY, exp.Distribute - ), + "window": lambda self: self._match(TokenType.WINDOW) and self._parse_window(self._parse_id_var(), alias=True), + "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute), "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort), "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster), "order": lambda self: self._parse_order(), @@ -392,6 +398,8 @@ class Parser: "offset": lambda self: self._parse_offset(), } + MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) + CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX} STRICT_CAST = True @@ -457,9 +465,7 @@ class Parser: Returns the list of syntax trees (:class:`~sqlglot.expressions.Expression`). """ - return self._parse( - parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql - ) + return self._parse(parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql) def parse_into(self, expression_types, raw_tokens, sql=None): for expression_type in ensure_list(expression_types): @@ -532,21 +538,13 @@ class Parser: for k in expression.args: if k not in expression.arg_types: - self.raise_error( - f"Unexpected keyword: '{k}' for {expression.__class__}" - ) + self.raise_error(f"Unexpected keyword: '{k}' for {expression.__class__}") for k, mandatory in expression.arg_types.items(): v = expression.args.get(k) if mandatory and (v is None or (isinstance(v, list) and not v)): - self.raise_error( - f"Required keyword: '{k}' missing for {expression.__class__}" - ) + self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}") - if ( - args - and len(args) > len(expression.arg_types) - and not expression.is_var_len_args - ): + if args and len(args) > len(expression.arg_types) and not expression.is_var_len_args: self.raise_error( f"The number of provided arguments ({len(args)}) is greater than " f"the maximum number of supported arguments ({len(expression.arg_types)})" @@ -594,11 +592,7 @@ class Parser: ) expression = self._parse_expression() - expression = ( - self._parse_set_operations(expression) - if expression - else self._parse_select() - ) + expression = self._parse_set_operations(expression) if expression else self._parse_select() self._parse_query_modifiers(expression) return expression @@ -618,11 +612,7 @@ class Parser: ) def _parse_exists(self, not_=False): - return ( - self._match(TokenType.IF) - and (not not_ or self._match(TokenType.NOT)) - and self._match(TokenType.EXISTS) - ) + return self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS) def _parse_create(self): replace = 
self._match(TokenType.OR) and self._match(TokenType.REPLACE) @@ -647,11 +637,9 @@ class Parser: this = self._parse_index() elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW): this = self._parse_table(schema=True) - properties = self._parse_properties( - this if isinstance(this, exp.Schema) else None - ) + properties = self._parse_properties(this if isinstance(this, exp.Schema) else None) if self._match(TokenType.ALIAS): - expression = self._parse_select() + expression = self._parse_select(nested=True) return self.expression( exp.Create, @@ -682,9 +670,7 @@ class Parser: if schema and not isinstance(value, exp.Schema): columns = {v.name.upper() for v in value.expressions} partitions = [ - expression - for expression in schema.expressions - if expression.this.name.upper() in columns + expression for expression in schema.expressions if expression.this.name.upper() in columns ] schema.set( "expressions", @@ -811,7 +797,7 @@ class Parser: this=self._parse_table(schema=True), exists=self._parse_exists(), partition=self._parse_partition(), - expression=self._parse_select(), + expression=self._parse_select(nested=True), overwrite=overwrite, ) @@ -829,8 +815,7 @@ class Parser: exp.Update, **{ "this": self._parse_table(schema=True), - "expressions": self._match(TokenType.SET) - and self._parse_csv(self._parse_equality), + "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality), "from": self._parse_from(), "where": self._parse_where(), }, @@ -865,7 +850,7 @@ class Parser: this=table, lazy=lazy, options=options, - expression=self._parse_select(), + expression=self._parse_select(nested=True), ) def _parse_partition(self): @@ -894,9 +879,7 @@ class Parser: self._match_r_paren() return self.expression(exp.Tuple, expressions=expressions) - def _parse_select(self, table=None): - index = self._index - + def _parse_select(self, nested=False, table=False): if self._match(TokenType.SELECT): hint = self._parse_hint() all_ = self._match(TokenType.ALL) @@ -912,9 +895,7 @@ class Parser: self.raise_error("Cannot specify both ALL and DISTINCT after SELECT") limit = self._parse_limit(top=True) - expressions = self._parse_csv( - lambda: self._parse_annotation(self._parse_expression()) - ) + expressions = self._parse_csv(lambda: self._parse_annotation(self._parse_expression())) this = self.expression( exp.Select, @@ -960,19 +941,13 @@ class Parser: ) else: self.raise_error(f"{this.key} does not support CTE") - elif self._match(TokenType.L_PAREN): - this = self._parse_table() if table else self._parse_select() - - if this: - self._parse_query_modifiers(this) - self._match_r_paren() - this = self._parse_subquery(this) - else: - self._retreat(index) + elif (table or nested) and self._match(TokenType.L_PAREN): + this = self._parse_table() if table else self._parse_select(nested=True) + self._parse_query_modifiers(this) + self._match_r_paren() + this = self._parse_subquery(this) elif self._match(TokenType.VALUES): - this = self.expression( - exp.Values, expressions=self._parse_csv(self._parse_value) - ) + this = self.expression(exp.Values, expressions=self._parse_csv(self._parse_value)) alias = self._parse_table_alias() if alias: this = self.expression(exp.Subquery, this=this, alias=alias) @@ -1001,7 +976,7 @@ class Parser: def _parse_table_alias(self): any_token = self._match(TokenType.ALIAS) - alias = self._parse_id_var(any_token) + alias = self._parse_id_var(any_token=any_token, tokens=self.TABLE_ALIAS_TOKENS) columns = None if self._match(TokenType.L_PAREN): @@ -1021,9 +996,24 
@@ class Parser: return self.expression(exp.Subquery, this=this, alias=self._parse_table_alias()) def _parse_query_modifiers(self, this): - if not isinstance(this, (exp.Subquery, exp.Subqueryable)): + if not isinstance(this, self.MODIFIABLES): return + table = isinstance(this, exp.Table) + + while True: + lateral = self._parse_lateral() + join = self._parse_join() + comma = None if table else self._match(TokenType.COMMA) + if lateral: + this.append("laterals", lateral) + if join: + this.append("joins", join) + if comma: + this.args["from"].append("expressions", self._parse_table()) + if not (lateral or join or comma): + break + for key, parser in self.QUERY_MODIFIER_PARSERS.items(): expression = parser(self) @@ -1032,9 +1022,7 @@ class Parser: def _parse_annotation(self, expression): if self._match(TokenType.ANNOTATION): - return self.expression( - exp.Annotation, this=self._prev.text, expression=expression - ) + return self.expression(exp.Annotation, this=self._prev.text, expression=expression) return expression @@ -1052,16 +1040,16 @@ class Parser: return self.expression(exp.From, expressions=self._parse_csv(self._parse_table)) - def _parse_laterals(self): - return self._parse_all(self._parse_lateral) - def _parse_lateral(self): if not self._match(TokenType.LATERAL): return None - if not self._match(TokenType.VIEW): - self.raise_error("Expected VIEW after LATERAL") + subquery = self._parse_select(table=True) + if subquery: + return self.expression(exp.Lateral, this=subquery) + + self._match(TokenType.VIEW) outer = self._match(TokenType.OUTER) return self.expression( @@ -1071,31 +1059,27 @@ class Parser: alias=self.expression( exp.TableAlias, this=self._parse_id_var(any_token=False), - columns=( - self._parse_csv(self._parse_id_var) - if self._match(TokenType.ALIAS) - else None - ), + columns=(self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else None), ), ) - def _parse_joins(self): - return self._parse_all(self._parse_join) - def _parse_join_side_and_kind(self): return ( + self._match(TokenType.NATURAL) and self._prev, self._match_set(self.JOIN_SIDES) and self._prev, self._match_set(self.JOIN_KINDS) and self._prev, ) def _parse_join(self): - side, kind = self._parse_join_side_and_kind() + natural, side, kind = self._parse_join_side_and_kind() if not self._match(TokenType.JOIN): return None kwargs = {"this": self._parse_table()} + if natural: + kwargs["natural"] = True if side: kwargs["side"] = side.text if kind: @@ -1120,6 +1104,11 @@ class Parser: ) def _parse_table(self, schema=False): + lateral = self._parse_lateral() + + if lateral: + return lateral + unnest = self._parse_unnest() if unnest: @@ -1172,9 +1161,7 @@ class Parser: expressions = self._parse_csv(self._parse_column) self._match_r_paren() - ordinality = bool( - self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY) - ) + ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY)) alias = self._parse_table_alias() @@ -1280,17 +1267,13 @@ class Parser: if not self._match(TokenType.ORDER_BY): return this - return self.expression( - exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered) - ) + return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)) def _parse_sort(self, token_type, exp_class): if not self._match(token_type): return None - return self.expression( - exp_class, expressions=self._parse_csv(self._parse_ordered) - ) + return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered)) def 
_parse_ordered(self): this = self._parse_conjunction() @@ -1305,22 +1288,17 @@ class Parser: if ( not explicitly_null_ordered and ( - (asc and self.null_ordering == "nulls_are_small") - or (desc and self.null_ordering != "nulls_are_small") + (asc and self.null_ordering == "nulls_are_small") or (desc and self.null_ordering != "nulls_are_small") ) and self.null_ordering != "nulls_are_last" ): nulls_first = True - return self.expression( - exp.Ordered, this=this, desc=desc, nulls_first=nulls_first - ) + return self.expression(exp.Ordered, this=this, desc=desc, nulls_first=nulls_first) def _parse_limit(self, this=None, top=False): if self._match(TokenType.TOP if top else TokenType.LIMIT): - return self.expression( - exp.Limit, this=this, expression=self._parse_number() - ) + return self.expression(exp.Limit, this=this, expression=self._parse_number()) if self._match(TokenType.FETCH): direction = self._match_set((TokenType.FIRST, TokenType.NEXT)) direction = self._prev.text if direction else "FIRST" @@ -1354,7 +1332,7 @@ class Parser: expression, this=this, distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL), - expression=self._parse_select(), + expression=self._parse_select(nested=True), ) def _parse_expression(self): @@ -1396,9 +1374,7 @@ class Parser: this = self.expression(exp.In, this=this, unnest=unnest) else: self._match_l_paren() - expressions = self._parse_csv( - lambda: self._parse_select() or self._parse_expression() - ) + expressions = self._parse_csv(lambda: self._parse_select() or self._parse_expression()) if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable): this = self.expression(exp.In, this=this, query=expressions[0]) @@ -1430,13 +1406,9 @@ class Parser: expression=self._parse_term(), ) elif self._match_pair(TokenType.LT, TokenType.LT): - this = self.expression( - exp.BitwiseLeftShift, this=this, expression=self._parse_term() - ) + this = self.expression(exp.BitwiseLeftShift, this=this, expression=self._parse_term()) elif self._match_pair(TokenType.GT, TokenType.GT): - this = self.expression( - exp.BitwiseRightShift, this=this, expression=self._parse_term() - ) + this = self.expression(exp.BitwiseRightShift, this=this, expression=self._parse_term()) else: break @@ -1524,7 +1496,7 @@ class Parser: self.raise_error("Expecting >") if type_token in self.TIMESTAMPS: - tz = self._match(TokenType.WITH_TIME_ZONE) + tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ self._match(TokenType.WITHOUT_TIME_ZONE) if tz: return exp.DataType( @@ -1594,16 +1566,14 @@ class Parser: if query: expressions = [query] else: - expressions = self._parse_csv( - lambda: self._parse_alias(self._parse_conjunction(), explicit=True) - ) + expressions = self._parse_csv(lambda: self._parse_alias(self._parse_conjunction(), explicit=True)) this = list_get(expressions, 0) self._parse_query_modifiers(this) self._match_r_paren() if isinstance(this, exp.Subqueryable): - return self._parse_subquery(this) + return self._parse_set_operations(self._parse_subquery(this)) if len(expressions) > 1: return self.expression(exp.Tuple, expressions=expressions) return self.expression(exp.Paren, this=this) @@ -1611,11 +1581,7 @@ class Parser: return None def _parse_field(self, any_token=False): - return ( - self._parse_primary() - or self._parse_function() - or self._parse_id_var(any_token) - ) + return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token) def _parse_function(self): if not self._curr: @@ -1628,21 +1594,22 @@ class 
Parser: if not self._next or self._next.token_type != TokenType.L_PAREN: if token_type in self.NO_PAREN_FUNCTIONS: - return self.expression( - self._advance() or self.NO_PAREN_FUNCTIONS[token_type] - ) + return self.expression(self._advance() or self.NO_PAREN_FUNCTIONS[token_type]) return None if token_type not in self.FUNC_TOKENS: return None - if self._match_set(self.FUNCTION_PARSERS): - self._advance() - this = self.FUNCTION_PARSERS[token_type](self, token_type) + this = self._curr.text + upper = this.upper() + self._advance(2) + + parser = self.FUNCTION_PARSERS.get(upper) + + if parser: + this = parser(self) else: subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type) - this = self._curr.text - self._advance(2) if subquery_predicate and self._curr.token_type in ( TokenType.SELECT, @@ -1652,7 +1619,7 @@ class Parser: self._match_r_paren() return this - function = self.FUNCTIONS.get(this.upper()) + function = self.FUNCTIONS.get(upper) args = self._parse_csv(self._parse_lambda) if function: @@ -1700,10 +1667,7 @@ class Parser: self._retreat(index) return this - args = self._parse_csv( - lambda: self._parse_constraint() - or self._parse_column_def(self._parse_field()) - ) + args = self._parse_csv(lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True))) self._match_r_paren() return self.expression(exp.Schema, this=this, expressions=args) @@ -1720,12 +1684,9 @@ class Parser: break constraints.append(constraint) - return self.expression( - exp.ColumnDef, this=this, kind=kind, constraints=constraints - ) + return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints) def _parse_column_constraint(self): - kind = None this = None if self._match(TokenType.CONSTRAINT): @@ -1735,28 +1696,28 @@ class Parser: kind = exp.AutoIncrementColumnConstraint() elif self._match(TokenType.CHECK): self._match_l_paren() - kind = self.expression( - exp.CheckColumnConstraint, this=self._parse_conjunction() - ) + kind = self.expression(exp.CheckColumnConstraint, this=self._parse_conjunction()) self._match_r_paren() elif self._match(TokenType.COLLATE): kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var()) elif self._match(TokenType.DEFAULT): - kind = self.expression( - exp.DefaultColumnConstraint, this=self._parse_field() - ) - elif self._match(TokenType.NOT) and self._match(TokenType.NULL): + kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_field()) + elif self._match_pair(TokenType.NOT, TokenType.NULL): kind = exp.NotNullColumnConstraint() elif self._match(TokenType.SCHEMA_COMMENT): - kind = self.expression( - exp.CommentColumnConstraint, this=self._parse_string() - ) + kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string()) elif self._match(TokenType.PRIMARY_KEY): kind = exp.PrimaryKeyColumnConstraint() elif self._match(TokenType.UNIQUE): kind = exp.UniqueColumnConstraint() - - if kind is None: + elif self._match(TokenType.GENERATED): + if self._match(TokenType.BY_DEFAULT): + kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False) + else: + self._match(TokenType.ALWAYS) + kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True) + self._match_pair(TokenType.ALIAS, TokenType.IDENTITY) + else: return None return self.expression(exp.ColumnConstraint, this=this, kind=kind) @@ -1864,9 +1825,7 @@ class Parser: if not self._match(TokenType.END): self.raise_error("Expected END after CASE", self._prev) - return self._parse_window( - self.expression(exp.Case, 
this=expression, ifs=ifs, default=default) - ) + return self._parse_window(self.expression(exp.Case, this=expression, ifs=ifs, default=default)) def _parse_if(self): if self._match(TokenType.L_PAREN): @@ -1889,7 +1848,7 @@ class Parser: if not self._match(TokenType.FROM): self.raise_error("Expected FROM after EXTRACT", self._prev) - return self.expression(exp.Extract, this=this, expression=self._parse_type()) + return self.expression(exp.Extract, this=this, expression=self._parse_bitwise()) def _parse_cast(self, strict): this = self._parse_conjunction() @@ -1917,12 +1876,54 @@ class Parser: to = None return self.expression(exp.Cast, this=this, to=to) + def _parse_substring(self): + # Postgres supports the form: substring(string [from int] [for int]) + # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 + + args = self._parse_csv(self._parse_bitwise) + + if self._match(TokenType.FROM): + args.append(self._parse_bitwise()) + if self._match(TokenType.FOR): + args.append(self._parse_bitwise()) + + this = exp.Substring.from_arg_list(args) + self.validate_expression(this, args) + + return this + + def _parse_trim(self): + # https://www.w3resource.com/sql/character-functions/trim.php + # https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html + + position = None + collation = None + + if self._match_set(self.TRIM_TYPES): + position = self._prev.text.upper() + + expression = self._parse_term() + if self._match(TokenType.FROM): + this = self._parse_term() + else: + this = expression + expression = None + + if self._match(TokenType.COLLATE): + collation = self._parse_term() + + return self.expression( + exp.Trim, + this=this, + position=position, + expression=expression, + collation=collation, + ) + def _parse_window(self, this, alias=False): if self._match(TokenType.FILTER): self._match_l_paren() - this = self.expression( - exp.Filter, this=this, expression=self._parse_where() - ) + this = self.expression(exp.Filter, this=this, expression=self._parse_where()) self._match_r_paren() if self._match(TokenType.WITHIN_GROUP): @@ -1935,6 +1936,25 @@ class Parser: self._match_r_paren() return this + # SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER + # Some dialects choose to implement and some do not. + # https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html + + # There is some code above in _parse_lambda that handles + # SELECT FIRST_VALUE(TABLE.COLUMN IGNORE|RESPECT NULLS) OVER ... + + # The below changes handle + # SELECT FIRST_VALUE(TABLE.COLUMN) IGNORE|RESPECT NULLS OVER ... + + # Oracle allows both formats + # (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html) + # and Snowflake chose to do the same for familiarity + # https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes + if self._match(TokenType.IGNORE_NULLS): + this = self.expression(exp.IgnoreNulls, this=this) + elif self._match(TokenType.RESPECT_NULLS): + this = self.expression(exp.RespectNulls, this=this) + # bigquery select from window x AS (partition by ...) 
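+ # e.g. SELECT SUM(a) OVER w FROM t WINDOW w AS (PARTITION BY b ORDER BY c)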
if alias: self._match(TokenType.ALIAS) @@ -1992,13 +2012,9 @@ class Parser: self._match(TokenType.BETWEEN) return { - "value": ( - self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) - and self._prev.text - ) + "value": (self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text) or self._parse_bitwise(), - "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) - and self._prev.text, + "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text, } def _parse_alias(self, this, explicit=False): @@ -2023,22 +2039,16 @@ class Parser: return this - def _parse_id_var(self, any_token=True): + def _parse_id_var(self, any_token=True, tokens=None): identifier = self._parse_identifier() if identifier: return identifier - if ( - any_token - and self._curr - and self._curr.token_type not in self.RESERVED_KEYWORDS - ): + if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: return self._advance() or exp.Identifier(this=self._prev.text, quoted=False) - return self._match_set(self.ID_VAR_TOKENS) and exp.Identifier( - this=self._prev.text, quoted=False - ) + return self._match_set(tokens or self.ID_VAR_TOKENS) and exp.Identifier(this=self._prev.text, quoted=False) def _parse_string(self): if self._match(TokenType.STRING): @@ -2077,9 +2087,7 @@ class Parser: def _parse_star(self): if self._match(TokenType.STAR): - return exp.Star( - **{"except": self._parse_except(), "replace": self._parse_replace()} - ) + return exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}) return None def _parse_placeholder(self): @@ -2117,15 +2125,10 @@ class Parser: this = parse() while self._match_set(expressions): - this = self.expression( - expressions[self._prev.token_type], this=this, expression=parse() - ) + this = self.expression(expressions[self._prev.token_type], this=this, expression=parse()) return this - def _parse_all(self, parse): - return list(iter(parse, None)) - def _parse_wrapped_id_vars(self): self._match_l_paren() expressions = self._parse_csv(self._parse_id_var) @@ -2156,10 +2159,7 @@ class Parser: if not self._curr or not self._next: return None - if ( - self._curr.token_type == token_type_a - and self._next.token_type == token_type_b - ): + if self._curr.token_type == token_type_a and self._next.token_type == token_type_b: if advance: self._advance(2) return True diff --git a/sqlglot/planner.py b/sqlglot/planner.py index 2006a75..ed0b66c 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -72,9 +72,7 @@ class Step: if from_: from_ = from_.expressions if len(from_) > 1: - raise UnsupportedError( - "Multi-from statements are unsupported. Run it through the optimizer" - ) + raise UnsupportedError("Multi-from statements are unsupported. 
Run it through the optimizer") step = Scan.from_expression(from_[0], ctes) else: @@ -104,9 +102,7 @@ class Step: continue if operand not in operands: operands[operand] = f"_a_{next(sequence)}" - operand.replace( - exp.column(operands[operand], step.name, quoted=True) - ) + operand.replace(exp.column(operands[operand], step.name, quoted=True)) else: projections.append(e) @@ -121,14 +117,9 @@ class Step: aggregate = Aggregate() aggregate.source = step.name aggregate.name = step.name - aggregate.operands = tuple( - alias(operand, alias_) for operand, alias_ in operands.items() - ) + aggregate.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items()) aggregate.aggregations = aggregations - aggregate.group = [ - exp.column(e.alias_or_name, step.name, quoted=True) - for e in group.expressions - ] + aggregate.group = [exp.column(e.alias_or_name, step.name, quoted=True) for e in group.expressions] aggregate.add_dependency(step) step = aggregate @@ -212,9 +203,7 @@ class Scan(Step): alias_ = expression.alias if not alias_: - raise UnsupportedError( - "Tables/Subqueries must be aliased. Run it through the optimizer" - ) + raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer") if isinstance(expression, exp.Subquery): step = Step.from_expression(table, ctes) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index e4b754d..bd95bc7 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -38,6 +38,7 @@ class TokenType(AutoName): DARROW = auto() HASH_ARROW = auto() DHASH_ARROW = auto() + LR_ARROW = auto() ANNOTATION = auto() DOLLAR = auto() @@ -53,6 +54,7 @@ class TokenType(AutoName): TABLE = auto() VAR = auto() BIT_STRING = auto() + HEX_STRING = auto() # types BOOLEAN = auto() @@ -78,10 +80,17 @@ class TokenType(AutoName): UUID = auto() GEOGRAPHY = auto() NULLABLE = auto() + GEOMETRY = auto() + HLLSKETCH = auto() + SUPER = auto() + SERIAL = auto() + SMALLSERIAL = auto() + BIGSERIAL = auto() # keywords ADD_FILE = auto() ALIAS = auto() + ALWAYS = auto() ALL = auto() ALTER = auto() ANALYZE = auto() @@ -92,11 +101,12 @@ class TokenType(AutoName): AUTO_INCREMENT = auto() BEGIN = auto() BETWEEN = auto() + BOTH = auto() BUCKET = auto() + BY_DEFAULT = auto() CACHE = auto() CALL = auto() CASE = auto() - CAST = auto() CHARACTER_SET = auto() CHECK = auto() CLUSTER_BY = auto() @@ -104,7 +114,6 @@ class TokenType(AutoName): COMMENT = auto() COMMIT = auto() CONSTRAINT = auto() - CONVERT = auto() CREATE = auto() CROSS = auto() CUBE = auto() @@ -127,22 +136,24 @@ class TokenType(AutoName): EXCEPT = auto() EXISTS = auto() EXPLAIN = auto() - EXTRACT = auto() FALSE = auto() FETCH = auto() FILTER = auto() FINAL = auto() FIRST = auto() FOLLOWING = auto() + FOR = auto() FOREIGN_KEY = auto() FORMAT = auto() FULL = auto() FUNCTION = auto() FROM = auto() + GENERATED = auto() GROUP_BY = auto() GROUPING_SETS = auto() HAVING = auto() HINT = auto() + IDENTITY = auto() IF = auto() IGNORE_NULLS = auto() ILIKE = auto() @@ -159,12 +170,14 @@ class TokenType(AutoName): JOIN = auto() LATERAL = auto() LAZY = auto() + LEADING = auto() LEFT = auto() LIKE = auto() LIMIT = auto() LOCATION = auto() MAP = auto() MOD = auto() + NATURAL = auto() NEXT = auto() NO_ACTION = auto() NULL = auto() @@ -204,8 +217,10 @@ class TokenType(AutoName): ROWS = auto() SCHEMA_COMMENT = auto() SELECT = auto() + SEPARATOR = auto() SET = auto() SHOW = auto() + SIMILAR_TO = auto() SOME = auto() SORT_BY = auto() STORED = auto() @@ -213,12 +228,11 @@ class TokenType(AutoName): TABLE_FORMAT = auto() 
TABLE_SAMPLE = auto() TEMPORARY = auto() - TIME = auto() TOP = auto() THEN = auto() TRUE = auto() + TRAILING = auto() TRUNCATE = auto() - TRY_CAST = auto() UNBOUNDED = auto() UNCACHE = auto() UNION = auto() @@ -272,35 +286,32 @@ class _Tokenizer(type): def __new__(cls, clsname, bases, attrs): klass = super().__new__(cls, clsname, bases, attrs) - klass.QUOTES = dict( - (quote, quote) if isinstance(quote, str) else (quote[0], quote[1]) - for quote in klass.QUOTES - ) - - klass.IDENTIFIERS = dict( - (identifier, identifier) - if isinstance(identifier, str) - else (identifier[0], identifier[1]) - for identifier in klass.IDENTIFIERS - ) - - klass.COMMENTS = dict( - (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) - for comment in klass.COMMENTS + klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES) + klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS) + klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS) + klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS) + klass._COMMENTS = dict( + (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS ) klass.KEYWORD_TRIE = new_trie( key.upper() for key, value in { **klass.KEYWORDS, - **{comment: TokenType.COMMENT for comment in klass.COMMENTS}, - **{quote: TokenType.QUOTE for quote in klass.QUOTES}, + **{comment: TokenType.COMMENT for comment in klass._COMMENTS}, + **{quote: TokenType.QUOTE for quote in klass._QUOTES}, + **{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS}, + **{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS}, }.items() if " " in key or any(single in key for single in klass.SINGLE_TOKENS) ) return klass + @staticmethod + def _delimeter_list_to_dict(list): + return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list) + class Tokenizer(metaclass=_Tokenizer): SINGLE_TOKENS = { @@ -339,6 +350,10 @@ class Tokenizer(metaclass=_Tokenizer): QUOTES = ["'"] + BIT_STRINGS = [] + + HEX_STRINGS = [] + IDENTIFIERS = ['"'] ESCAPE = "'" @@ -357,6 +372,7 @@ class Tokenizer(metaclass=_Tokenizer): "->>": TokenType.DARROW, "#>": TokenType.HASH_ARROW, "#>>": TokenType.DHASH_ARROW, + "<->": TokenType.LR_ARROW, "ADD ARCHIVE": TokenType.ADD_FILE, "ADD ARCHIVES": TokenType.ADD_FILE, "ADD FILE": TokenType.ADD_FILE, @@ -374,12 +390,12 @@ class Tokenizer(metaclass=_Tokenizer): "AUTO_INCREMENT": TokenType.AUTO_INCREMENT, "BEGIN": TokenType.BEGIN, "BETWEEN": TokenType.BETWEEN, + "BOTH": TokenType.BOTH, "BUCKET": TokenType.BUCKET, "CALL": TokenType.CALL, "CACHE": TokenType.CACHE, "UNCACHE": TokenType.UNCACHE, "CASE": TokenType.CASE, - "CAST": TokenType.CAST, "CHARACTER SET": TokenType.CHARACTER_SET, "CHECK": TokenType.CHECK, "CLUSTER BY": TokenType.CLUSTER_BY, @@ -387,7 +403,6 @@ class Tokenizer(metaclass=_Tokenizer): "COMMENT": TokenType.SCHEMA_COMMENT, "COMMIT": TokenType.COMMIT, "CONSTRAINT": TokenType.CONSTRAINT, - "CONVERT": TokenType.CONVERT, "CREATE": TokenType.CREATE, "CROSS": TokenType.CROSS, "CUBE": TokenType.CUBE, @@ -408,7 +423,6 @@ class Tokenizer(metaclass=_Tokenizer): "EXCEPT": TokenType.EXCEPT, "EXISTS": TokenType.EXISTS, "EXPLAIN": TokenType.EXPLAIN, - "EXTRACT": TokenType.EXTRACT, "FALSE": TokenType.FALSE, "FETCH": TokenType.FETCH, "FILTER": TokenType.FILTER, @@ -437,10 +451,12 @@ class Tokenizer(metaclass=_Tokenizer): "JOIN": TokenType.JOIN, "LATERAL": TokenType.LATERAL, "LAZY": TokenType.LAZY, + "LEADING": TokenType.LEADING, "LEFT": TokenType.LEFT, 
"LIKE": TokenType.LIKE, "LIMIT": TokenType.LIMIT, "LOCATION": TokenType.LOCATION, + "NATURAL": TokenType.NATURAL, "NEXT": TokenType.NEXT, "NO ACTION": TokenType.NO_ACTION, "NOT": TokenType.NOT, @@ -490,8 +506,8 @@ class Tokenizer(metaclass=_Tokenizer): "TEMPORARY": TokenType.TEMPORARY, "THEN": TokenType.THEN, "TRUE": TokenType.TRUE, + "TRAILING": TokenType.TRAILING, "TRUNCATE": TokenType.TRUNCATE, - "TRY_CAST": TokenType.TRY_CAST, "UNBOUNDED": TokenType.UNBOUNDED, "UNION": TokenType.UNION, "UNNEST": TokenType.UNNEST, @@ -626,14 +642,12 @@ class Tokenizer(metaclass=_Tokenizer): break white_space = self.WHITE_SPACE.get(self._char) - identifier_end = self.IDENTIFIERS.get(self._char) + identifier_end = self._IDENTIFIERS.get(self._char) if white_space: if white_space == TokenType.BREAK: self._col = 1 self._line += 1 - elif self._char == "0" and self._peek == "x": - self._scan_hex() elif self._char.isdigit(): self._scan_number() elif identifier_end: @@ -666,9 +680,7 @@ class Tokenizer(metaclass=_Tokenizer): text = self._text if text is None else text self.tokens.append(Token(token_type, text, self._line, self._col)) - if token_type in self.COMMANDS and ( - len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON - ): + if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON): self._start = self._current while not self._end and self._peek != ";": self._advance() @@ -725,6 +737,8 @@ class Tokenizer(metaclass=_Tokenizer): if self._scan_string(word): return + if self._scan_numeric_string(word): + return if self._scan_comment(word): return @@ -732,10 +746,10 @@ class Tokenizer(metaclass=_Tokenizer): self._add(self.KEYWORDS[word.upper()]) def _scan_comment(self, comment_start): - if comment_start not in self.COMMENTS: + if comment_start not in self._COMMENTS: return False - comment_end = self.COMMENTS[comment_start] + comment_end = self._COMMENTS[comment_start] if comment_end: comment_end_size = len(comment_end) @@ -749,15 +763,18 @@ class Tokenizer(metaclass=_Tokenizer): return True def _scan_annotation(self): - while ( - not self._end - and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK - and self._peek != "," - ): + while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK and self._peek != ",": self._advance() self._add(TokenType.ANNOTATION, self._text[1:]) def _scan_number(self): + if self._char == "0": + peek = self._peek.upper() + if peek == "B": + return self._scan_bits() + elif peek == "X": + return self._scan_hex() + decimal = False scientific = 0 @@ -788,57 +805,71 @@ class Tokenizer(metaclass=_Tokenizer): else: return self._add(TokenType.NUMBER) + def _scan_bits(self): + self._advance() + value = self._extract_value() + try: + self._add(TokenType.BIT_STRING, f"{int(value, 2)}") + except ValueError: + self._add(TokenType.IDENTIFIER) + def _scan_hex(self): self._advance() + value = self._extract_value() + try: + self._add(TokenType.HEX_STRING, f"{int(value, 16)}") + except ValueError: + self._add(TokenType.IDENTIFIER) + def _extract_value(self): while True: char = self._peek.strip() if char and char not in self.SINGLE_TOKENS: self._advance() else: break - try: - self._add(TokenType.BIT_STRING, f"{int(self._text, 16):b}") - except ValueError: - self._add(TokenType.IDENTIFIER) + + return self._text def _scan_string(self, quote): - quote_end = self.QUOTES.get(quote) + quote_end = self._QUOTES.get(quote) if quote_end is None: return False - text = "" self._advance(len(quote)) - quote_end_size = 
len(quote_end) - - while True: - if self._char == self.ESCAPE and self._peek == quote_end: - text += quote - self._advance(2) - else: - if self._chars(quote_end_size) == quote_end: - if quote_end_size > 1: - self._advance(quote_end_size - 1) - break - - if self._end: - raise RuntimeError( - f"Missing {quote} from {self._line}:{self._start}" - ) - text += self._char - self._advance() + text = self._extract_string(quote_end) text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text self._add(TokenType.STRING, text) return True + def _scan_numeric_string(self, string_start): + if string_start in self._HEX_STRINGS: + delimiters = self._HEX_STRINGS + token_type = TokenType.HEX_STRING + base = 16 + elif string_start in self._BIT_STRINGS: + delimiters = self._BIT_STRINGS + token_type = TokenType.BIT_STRING + base = 2 + else: + return False + + self._advance(len(string_start)) + string_end = delimiters.get(string_start) + text = self._extract_string(string_end) + + try: + self._add(token_type, f"{int(text, base)}") + except ValueError: + raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}") + return True + def _scan_identifier(self, identifier_end): while self._peek != identifier_end: if self._end: - raise RuntimeError( - f"Missing {identifier_end} from {self._line}:{self._start}" - ) + raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}") self._advance() self._advance() self._add(TokenType.IDENTIFIER, self._text[1:-1]) @@ -851,3 +882,24 @@ class Tokenizer(metaclass=_Tokenizer): else: break self._add(self.KEYWORDS.get(self._text.upper(), TokenType.VAR)) + + def _extract_string(self, delimiter): + text = "" + delim_size = len(delimiter) + + while True: + if self._char == self.ESCAPE and self._peek == delimiter: + text += delimiter + self._advance(2) + else: + if self._chars(delim_size) == delimiter: + if delim_size > 1: + self._advance(delim_size - 1) + break + + if self._end: + raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}") + text += self._char + self._advance() + + return text diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index e7ccb8e..7fc71dd 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -12,9 +12,7 @@ def unalias_group(expression): """ if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): aliased_selects = { - e.alias: i - for i, e in enumerate(expression.parent.expressions, start=1) - if isinstance(e, exp.Alias) + e.alias: i for i, e in enumerate(expression.parent.expressions, start=1) if isinstance(e, exp.Alias) } expression = expression.copy() -- cgit v1.2.3
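A minimal usage sketch of the rule-driven optimizer introduced above (assuming the 6.1.1 API in this patch; the schema mapping and example SQL are illustrative, not part of the patch):

    import sqlglot
    from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
    from sqlglot.optimizer.optimizer import RULES, optimize

    # optimize() now iterates the RULES tuple; schema/db/catalog and any
    # extra kwargs are forwarded to each rule whose code uses a matching name.
    sql = "SELECT a FROM (SELECT x.a FROM x) AS y"
    print(optimize(sqlglot.parse_one(sql), schema={"x": {"a": "INT"}}).sql())

    # A single rule can be run in isolation, per its doctest:
    print(merge_derived_tables(sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)")).sql())
    # 'SELECT x.a FROM x'

    # Or run a custom subset of rewrites:
    no_merge = tuple(r for r in RULES if r is not merge_derived_tables)
    print(optimize(sqlglot.parse_one(sql), schema={"x": {"a": "INT"}}, rules=no_merge).sql())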