diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-12-02 09:16:32 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-12-02 09:16:32 +0000 |
commit | b3c7fe6a73484a4d2177c30f951cd11a4916ed56 (patch) | |
tree | 7192898cb782bbb0b9b13bd8d6341fe4434f0f31 /sqlglot | |
parent | Releasing debian version 10.0.8-1. (diff) | |
download | sqlglot-b3c7fe6a73484a4d2177c30f951cd11a4916ed56.tar.xz sqlglot-b3c7fe6a73484a4d2177c30f951cd11a4916ed56.zip |
Merging upstream version 10.1.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
28 files changed, 827 insertions, 389 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 50e2d9c..b027ac7 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -30,7 +30,7 @@ from sqlglot.parser import Parser from sqlglot.schema import MappingSchema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "10.0.8" +__version__ = "10.1.3" pretty = False diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 4550d65..5b44912 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -56,12 +56,12 @@ def _derived_table_values_to_unnest(self, expression): def _returnsproperty_sql(self, expression): - value = expression.args.get("value") - if isinstance(value, exp.Schema): - value = f"{value.this} <{self.expressions(value)}>" + this = expression.this + if isinstance(this, exp.Schema): + this = f"{this.this} <{self.expressions(this)}>" else: - value = self.sql(value) - return f"RETURNS {value}" + this = self.sql(this) + return f"RETURNS {this}" def _create_sql(self, expression): @@ -142,6 +142,11 @@ class BigQuery(Dialect): ), } + FUNCTION_PARSERS = { + **parser.Parser.FUNCTION_PARSERS, + } + FUNCTION_PARSERS.pop("TRIM") + NO_PAREN_FUNCTIONS = { **parser.Parser.NO_PAREN_FUNCTIONS, TokenType.CURRENT_DATETIME: exp.CurrentDatetime, @@ -174,6 +179,7 @@ class BigQuery(Dialect): exp.Values: _derived_table_values_to_unnest, exp.ReturnsProperty: _returnsproperty_sql, exp.Create: _create_sql, + exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})", exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC", @@ -200,9 +206,7 @@ class BigQuery(Dialect): exp.VolatilityProperty, } - WITH_PROPERTIES = { - exp.AnonymousProperty, - } + WITH_PROPERTIES = {exp.Property} EXPLICIT_UNION = True diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 332b4c1..cbed72e 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -21,14 +21,15 @@ class ClickHouse(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, - "FINAL": TokenType.FINAL, + "ASOF": TokenType.ASOF, "DATETIME64": TokenType.DATETIME, - "INT8": TokenType.TINYINT, + "FINAL": TokenType.FINAL, + "FLOAT32": TokenType.FLOAT, + "FLOAT64": TokenType.DOUBLE, "INT16": TokenType.SMALLINT, "INT32": TokenType.INT, "INT64": TokenType.BIGINT, - "FLOAT32": TokenType.FLOAT, - "FLOAT64": TokenType.DOUBLE, + "INT8": TokenType.TINYINT, "TUPLE": TokenType.STRUCT, } @@ -38,6 +39,10 @@ class ClickHouse(Dialect): "MAP": parse_var_map, } + JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} + + TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} + def _parse_table(self, schema=False): this = super()._parse_table(schema) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 8c497ab..c87f8d8 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -289,19 +289,19 @@ def struct_extract_sql(self, expression): return f"{this}.{struct_key}" -def var_map_sql(self, expression): +def var_map_sql(self, expression, map_func_name="MAP"): keys = expression.args["keys"] values = expression.args["values"] if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): self.unsupported("Cannot convert array columns into map.") - return f"MAP({self.format_args(keys, values)})" + return f"{map_func_name}({self.format_args(keys, values)})" args = [] for key, value in zip(keys.expressions, values.expressions): args.append(self.sql(key)) args.append(self.sql(value)) - return f"MAP({self.format_args(*args)})" + return f"{map_func_name}({self.format_args(*args)})" def format_time_lambda(exp_class, dialect, default=None): @@ -336,18 +336,13 @@ def create_with_partitions_sql(self, expression): if has_schema and is_partitionable: expression = expression.copy() prop = expression.find(exp.PartitionedByProperty) - value = prop and prop.args.get("value") - if prop and not isinstance(value, exp.Schema): + this = prop and prop.this + if prop and not isinstance(this, exp.Schema): schema = expression.this - columns = {v.name.upper() for v in value.expressions} + columns = {v.name.upper() for v in this.expressions} partitions = [col for col in schema.expressions if col.name.upper() in columns] - schema.set( - "expressions", - [e for e in schema.expressions if e not in partitions], - ) - prop.replace( - exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions)) - ) + schema.set("expressions", [e for e in schema.expressions if e not in partitions]) + prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) expression.set("this", schema) return self.create_sql(expression) diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index eb420aa..358eced 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -153,7 +153,7 @@ class Drill(Dialect): exp.If: if_sql, exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}", exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), - exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}", + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.Pivot: no_pivot_sql, exp.RegexpLike: rename_func("REGEXP_MATCHES"), exp.StrPosition: str_position_sql, diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index cff7139..cbb39c2 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -61,9 +61,7 @@ def _array_sort(self, expression): def _property_sql(self, expression): - key = expression.name - value = self.sql(expression, "value") - return f"'{key}'={value}" + return f"'{expression.name}'={self.sql(expression, 'value')}" def _str_to_unix(self, expression): @@ -250,7 +248,7 @@ class Hive(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, **transforms.UNALIAS_GROUP, # type: ignore - exp.AnonymousProperty: _property_sql, + exp.Property: _property_sql, exp.ApproxDistinct: approx_count_distinct_sql, exp.ArrayAgg: rename_func("COLLECT_LIST"), exp.ArrayConcat: rename_func("CONCAT"), @@ -262,7 +260,7 @@ class Hive(Dialect): exp.DateStrToDate: rename_func("TO_DATE"), exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})", - exp.FileFormatProperty: lambda self, e: f"STORED AS {e.text('value').upper()}", + exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}", exp.If: if_sql, exp.Index: _index_sql, exp.ILike: no_ilike_sql, @@ -285,7 +283,7 @@ class Hive(Dialect): exp.StrToTime: _str_to_time, exp.StrToUnix: _str_to_unix, exp.StructExtract: struct_extract_sql, - exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'value')}", + exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}", exp.TimeStrToDate: rename_func("TO_DATE"), exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), @@ -298,11 +296,11 @@ class Hive(Dialect): exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})", exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), - exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}", + exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}", exp.NumberToStr: rename_func("FORMAT_NUMBER"), } - WITH_PROPERTIES = {exp.AnonymousProperty} + WITH_PROPERTIES = {exp.Property} ROOT_PROPERTIES = { exp.PartitionedByProperty, diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 93a60f4..7627b6e 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -453,6 +453,7 @@ class MySQL(Dialect): exp.CharacterSetProperty, exp.CollateProperty, exp.SchemaCommentProperty, + exp.LikeProperty, } WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set() diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 870d2b9..ceaf9ba 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -1,7 +1,7 @@ from __future__ import annotations -from sqlglot import exp, generator, tokens, transforms -from sqlglot.dialects.dialect import Dialect, no_ilike_sql +from sqlglot import exp, generator, parser, tokens, transforms +from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func from sqlglot.helper import csv from sqlglot.tokens import TokenType @@ -37,6 +37,12 @@ class Oracle(Dialect): "YYYY": "%Y", # 2015 } + class Parser(parser.Parser): + FUNCTIONS = { + **parser.Parser.FUNCTIONS, + "DECODE": exp.Matches.from_arg_list, + } + class Generator(generator.Generator): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -58,6 +64,7 @@ class Oracle(Dialect): **transforms.UNALIAS_GROUP, # type: ignore exp.ILike: no_ilike_sql, exp.Limit: _limit_sql, + exp.Matches: rename_func("DECODE"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 4353164..1cb5025 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -74,6 +74,27 @@ def _trim_sql(self, expression): return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" +def _string_agg_sql(self, expression): + expression = expression.copy() + separator = expression.args.get("separator") or exp.Literal.string(",") + + order = "" + this = expression.this + if isinstance(this, exp.Order): + if this.this: + this = this.this + this.pop() + order = self.sql(expression.this) # Order has a leading space + + return f"STRING_AGG({self.format_args(this, separator)}{order})" + + +def _datatype_sql(self, expression): + if expression.this == exp.DataType.Type.ARRAY: + return f"{self.expressions(expression, flat=True)}[]" + return self.datatype_sql(expression) + + def _auto_increment_to_serial(expression): auto = expression.find(exp.AutoIncrementColumnConstraint) @@ -191,25 +212,27 @@ class Postgres(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "ALWAYS": TokenType.ALWAYS, - "BY DEFAULT": TokenType.BY_DEFAULT, - "IDENTITY": TokenType.IDENTITY, - "GENERATED": TokenType.GENERATED, - "DOUBLE PRECISION": TokenType.DOUBLE, - "BIGSERIAL": TokenType.BIGSERIAL, - "SERIAL": TokenType.SERIAL, - "SMALLSERIAL": TokenType.SMALLSERIAL, - "UUID": TokenType.UUID, - "TEMP": TokenType.TEMPORARY, - "BEGIN TRANSACTION": TokenType.BEGIN, "BEGIN": TokenType.COMMAND, + "BEGIN TRANSACTION": TokenType.BEGIN, + "BIGSERIAL": TokenType.BIGSERIAL, + "BY DEFAULT": TokenType.BY_DEFAULT, "COMMENT ON": TokenType.COMMAND, "DECLARE": TokenType.COMMAND, "DO": TokenType.COMMAND, + "DOUBLE PRECISION": TokenType.DOUBLE, + "GENERATED": TokenType.GENERATED, + "GRANT": TokenType.COMMAND, + "HSTORE": TokenType.HSTORE, + "IDENTITY": TokenType.IDENTITY, + "JSONB": TokenType.JSONB, "REFRESH": TokenType.COMMAND, "REINDEX": TokenType.COMMAND, "RESET": TokenType.COMMAND, "REVOKE": TokenType.COMMAND, - "GRANT": TokenType.COMMAND, + "SERIAL": TokenType.SERIAL, + "SMALLSERIAL": TokenType.SMALLSERIAL, + "TEMP": TokenType.TEMPORARY, + "UUID": TokenType.UUID, **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES}, **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES}, } @@ -265,4 +288,7 @@ class Postgres(Dialect): exp.Trim: _trim_sql, exp.TryCast: no_trycast_sql, exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", + exp.DataType: _datatype_sql, + exp.GroupConcat: _string_agg_sql, + exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", } diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 9d5cc11..1a09037 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -171,16 +171,7 @@ class Presto(Dialect): STRUCT_DELIMITER = ("(", ")") - ROOT_PROPERTIES = { - exp.SchemaCommentProperty, - } - - WITH_PROPERTIES = { - exp.PartitionedByProperty, - exp.FileFormatProperty, - exp.AnonymousProperty, - exp.TableFormatProperty, - } + ROOT_PROPERTIES = {exp.SchemaCommentProperty} TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -231,7 +222,8 @@ class Presto(Dialect): exp.StrToTime: _str_to_time_sql, exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", exp.StructExtract: struct_extract_sql, - exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT = '{e.text('value').upper()}'", + exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'", + exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.TimeStrToDate: _date_parse_sql, exp.TimeStrToTime: _date_parse_sql, exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))", diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index a9b12fb..cd50979 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -1,6 +1,6 @@ from __future__ import annotations -from sqlglot import exp +from sqlglot import exp, transforms from sqlglot.dialects.postgres import Postgres from sqlglot.tokens import TokenType @@ -18,12 +18,14 @@ class Redshift(Postgres): KEYWORDS = { **Postgres.Tokenizer.KEYWORDS, # type: ignore + "COPY": TokenType.COMMAND, "GEOMETRY": TokenType.GEOMETRY, "GEOGRAPHY": TokenType.GEOGRAPHY, "HLLSKETCH": TokenType.HLLSKETCH, "SUPER": TokenType.SUPER, "TIME": TokenType.TIMESTAMP, "TIMETZ": TokenType.TIMESTAMPTZ, + "UNLOAD": TokenType.COMMAND, "VARBYTE": TokenType.VARBINARY, "SIMILAR TO": TokenType.SIMILAR_TO, } @@ -35,3 +37,17 @@ class Redshift(Postgres): exp.DataType.Type.VARBINARY: "VARBYTE", exp.DataType.Type.INT: "INTEGER", } + + ROOT_PROPERTIES = { + exp.DistKeyProperty, + exp.SortKeyProperty, + exp.DistStyleProperty, + } + + TRANSFORMS = { + **Postgres.Generator.TRANSFORMS, # type: ignore + **transforms.ELIMINATE_DISTINCT_ON, # type: ignore + exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})", + exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", + exp.DistStyleProperty: lambda self, e: self.naked_property(e), + } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index a96bd80..46155ff 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import ( format_time_lambda, inline_array_sql, rename_func, + var_map_sql, ) from sqlglot.expressions import Literal from sqlglot.helper import seq_get @@ -100,6 +101,14 @@ def _parse_date_part(self): return self.expression(exp.Extract, this=this, expression=expression) +def _datatype_sql(self, expression): + if expression.this == exp.DataType.Type.ARRAY: + return "ARRAY" + elif expression.this == exp.DataType.Type.MAP: + return "OBJECT" + return self.datatype_sql(expression) + + class Snowflake(Dialect): null_ordering = "nulls_are_large" time_format = "'yyyy-mm-dd hh24:mi:ss'" @@ -142,6 +151,8 @@ class Snowflake(Dialect): "TO_TIMESTAMP": _snowflake_to_timestamp, "ARRAY_CONSTRUCT": exp.Array.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list, + "DECODE": exp.Matches.from_arg_list, + "OBJECT_CONSTRUCT": parser.parse_var_map, } FUNCTION_PARSERS = { @@ -195,16 +206,20 @@ class Snowflake(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), + exp.DataType: _datatype_sql, exp.If: rename_func("IFF"), + exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), + exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), + exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}", + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", + exp.Matches: rename_func("DECODE"), + exp.StrPosition: rename_func("POSITION"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", - exp.UnixToTime: _unix_to_time_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", - exp.Array: inline_array_sql, - exp.StrPosition: rename_func("POSITION"), - exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}", - exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}", exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})", + exp.UnixToTime: _unix_to_time_sql, } TYPE_MAPPING = { diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 4e404b8..16083d1 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -98,7 +98,7 @@ class Spark(Hive): TRANSFORMS = { **Hive.Generator.TRANSFORMS, # type: ignore exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), - exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}", + exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 87b98a5..bbb752b 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -13,6 +13,23 @@ from sqlglot.dialects.dialect import ( from sqlglot.tokens import TokenType +# https://www.sqlite.org/lang_aggfunc.html#group_concat +def _group_concat_sql(self, expression): + this = expression.this + distinct = expression.find(exp.Distinct) + if distinct: + this = distinct.expressions[0] + distinct = "DISTINCT " + + if isinstance(expression.this, exp.Order): + self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.") + if expression.this.this and not distinct: + this = expression.this.this + + separator = expression.args.get("separator") + return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})" + + class SQLite(Dialect): class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]"), "`"] @@ -62,6 +79,7 @@ class SQLite(Dialect): exp.Levenshtein: rename_func("EDITDIST3"), exp.TableSample: no_tablesample_sql, exp.TryCast: no_trycast_sql, + exp.GroupConcat: _group_concat_sql, } def transaction_sql(self, expression): diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index d3b83de..07ce38b 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -17,6 +17,7 @@ FULL_FORMAT_TIME_MAPPING = { "mm": "%B", "m": "%B", } + DATE_DELTA_INTERVAL = { "year": "year", "yyyy": "year", @@ -37,11 +38,12 @@ DATE_DELTA_INTERVAL = { DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})") + # N = Numeric, C=Currency TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"} -def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None): +def _format_time_lambda(exp_class, full_format_mapping=None, default=None): def _format_time(args): return exp_class( this=seq_get(args, 1), @@ -58,7 +60,7 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None): return _format_time -def parse_format(args): +def _parse_format(args): fmt = seq_get(args, 1) number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this) if number_fmt: @@ -78,7 +80,7 @@ def generate_date_delta_with_unit_sql(self, e): return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})" -def generate_format_sql(self, e): +def _format_sql(self, e): fmt = ( e.args["format"] if isinstance(e, exp.NumberToStr) @@ -87,6 +89,28 @@ def generate_format_sql(self, e): return f"FORMAT({self.format_args(e.this, fmt)})" +def _string_agg_sql(self, e): + e = e.copy() + + this = e.this + distinct = e.find(exp.Distinct) + if distinct: + # exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression + self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.") + this = distinct.expressions[0] + distinct.pop() + + order = "" + if isinstance(e.this, exp.Order): + if e.this.this: + this = e.this.this + e.this.this.pop() + order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space + + separator = e.args.get("separator") or exp.Literal.string(",") + return f"STRING_AGG({self.format_args(this, separator)}){order}" + + class TSQL(Dialect): null_ordering = "nulls_are_small" time_format = "'yyyy-mm-dd hh:mm:ss'" @@ -228,14 +252,14 @@ class TSQL(Dialect): "ISNULL": exp.Coalesce.from_arg_list, "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), - "DATENAME": tsql_format_time_lambda(exp.TimeToStr, full_format_mapping=True), - "DATEPART": tsql_format_time_lambda(exp.TimeToStr), + "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True), + "DATEPART": _format_time_lambda(exp.TimeToStr), "GETDATE": exp.CurrentDate.from_arg_list, "IIF": exp.If.from_arg_list, "LEN": exp.Length.from_arg_list, "REPLICATE": exp.Repeat.from_arg_list, "JSON_VALUE": exp.JSONExtractScalar.from_arg_list, - "FORMAT": parse_format, + "FORMAT": _parse_format, } VAR_LENGTH_DATATYPES = { @@ -298,6 +322,7 @@ class TSQL(Dialect): exp.DateDiff: generate_date_delta_with_unit_sql, exp.CurrentDate: rename_func("GETDATE"), exp.If: rename_func("IIF"), - exp.NumberToStr: generate_format_sql, - exp.TimeToStr: generate_format_sql, + exp.NumberToStr: _format_sql, + exp.TimeToStr: _format_sql, + exp.GroupConcat: _string_agg_sql, } diff --git a/sqlglot/errors.py b/sqlglot/errors.py index 23a08bd..b5ef5ad 100644 --- a/sqlglot/errors.py +++ b/sqlglot/errors.py @@ -22,7 +22,40 @@ class UnsupportedError(SqlglotError): class ParseError(SqlglotError): - pass + def __init__( + self, + message: str, + errors: t.Optional[t.List[t.Dict[str, t.Any]]] = None, + ): + super().__init__(message) + self.errors = errors or [] + + @classmethod + def new( + cls, + message: str, + description: t.Optional[str] = None, + line: t.Optional[int] = None, + col: t.Optional[int] = None, + start_context: t.Optional[str] = None, + highlight: t.Optional[str] = None, + end_context: t.Optional[str] = None, + into_expression: t.Optional[str] = None, + ) -> ParseError: + return cls( + message, + [ + { + "description": description, + "line": line, + "col": col, + "start_context": start_context, + "highlight": highlight, + "end_context": end_context, + "into_expression": into_expression, + } + ], + ) class TokenError(SqlglotError): @@ -41,9 +74,13 @@ class ExecuteError(SqlglotError): pass -def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str: +def concat_messages(errors: t.Sequence[t.Any], maximum: int) -> str: msg = [str(e) for e in errors[:maximum]] remaining = len(errors) - maximum if remaining > 0: msg.append(f"... and {remaining} more") return "\n\n".join(msg) + + +def merge_errors(errors: t.Sequence[ParseError]) -> t.List[t.Dict[str, t.Any]]: + return [e_dict for error in errors for e_dict in error.errors] diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index ed80cc9..e6cfcdd 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -122,7 +122,6 @@ def interval(this, unit): ENV = { - "__builtins__": {}, "exp": exp, # aggs "SUM": filter_nulls(sum), diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index cb2543c..908b80a 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -115,6 +115,9 @@ class PythonExecutor: sink = self.table(context.columns) for reader in table_iter: + if len(sink) >= step.limit: + break + if condition and not context.eval(condition): continue @@ -123,9 +126,6 @@ class PythonExecutor: else: sink.append(reader.row) - if len(sink) >= step.limit: - break - return self.context({step.name: sink}) def static(self): @@ -288,21 +288,32 @@ class PythonExecutor: end = 1 length = len(context.table) table = self.table(list(step.group) + step.aggregations) + condition = self.generate(step.condition) - for i in range(length): - context.set_index(i) - key = context.eval_tuple(group_by) - group = key if group is None else group - end += 1 - if key != group: - context.set_range(start, end - 2) - table.append(group + context.eval_tuple(aggregations)) - group = key - start = end - 2 - if i == length - 1: - context.set_range(start, end - 1) + def add_row(): + if not condition or context.eval(condition): table.append(group + context.eval_tuple(aggregations)) + if length: + for i in range(length): + context.set_index(i) + key = context.eval_tuple(group_by) + group = key if group is None else group + end += 1 + if key != group: + context.set_range(start, end - 2) + add_row() + group = key + start = end - 2 + if len(table.rows) >= step.limit: + break + if i == length - 1: + context.set_range(start, end - 1) + add_row() + elif step.limit > 0: + context.set_range(0, 0) + table.append(context.eval_tuple(group_by) + context.eval_tuple(aggregations)) + context = self.context({step.name: table, **{name: table for name in context.tables}}) if step.projections: @@ -311,11 +322,9 @@ class PythonExecutor: def sort(self, step, context): projections = self.generate_tuple(step.projections) - projection_columns = [p.alias_or_name for p in step.projections] all_columns = list(context.columns) + projection_columns sink = self.table(all_columns) - for reader, ctx in context: sink.append(reader.row + ctx.eval_tuple(projections)) @@ -401,8 +410,9 @@ class Python(Dialect): exp.Boolean: lambda self, e: "True" if e.this else "False", exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})", exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]", + exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})", exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", - exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}", + exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})", exp.Is: lambda self, e: self.binary(e, "is"), exp.Not: lambda self, e: f"not {self.sql(e.this)}", exp.Null: lambda *_: "None", diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index beafca8..96b32f1 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -43,14 +43,14 @@ class Expression(metaclass=_Expression): key = "Expression" arg_types = {"this": True} - __slots__ = ("args", "parent", "arg_key", "type", "comment") + __slots__ = ("args", "parent", "arg_key", "type", "comments") def __init__(self, **args): self.args = args self.parent = None self.arg_key = None self.type = None - self.comment = None + self.comments = None for arg_key, value in self.args.items(): self._set_parent(arg_key, value) @@ -88,19 +88,6 @@ class Expression(metaclass=_Expression): return field.this return "" - def find_comment(self, key: str) -> str: - """ - Finds the comment that is attached to a specified child node. - - Args: - key: the key of the target child node (e.g. "this", "expression", etc). - - Returns: - The comment attached to the child node, or the empty string, if it doesn't exist. - """ - field = self.args.get(key) - return field.comment if isinstance(field, Expression) else "" - @property def is_string(self): return isinstance(self, Literal) and self.args["is_string"] @@ -137,7 +124,7 @@ class Expression(metaclass=_Expression): def __deepcopy__(self, memo): copy = self.__class__(**deepcopy(self.args)) - copy.comment = self.comment + copy.comments = self.comments copy.type = self.type return copy @@ -369,7 +356,7 @@ class Expression(metaclass=_Expression): ) for k, vs in self.args.items() } - args["comment"] = self.comment + args["comments"] = self.comments args["type"] = self.type args = {k: v for k, v in args.items() if v or not hide_missing} @@ -767,7 +754,7 @@ class NotNullColumnConstraint(ColumnConstraintKind): class PrimaryKeyColumnConstraint(ColumnConstraintKind): - pass + arg_types = {"desc": False} class UniqueColumnConstraint(ColumnConstraintKind): @@ -819,6 +806,12 @@ class Unique(Expression): arg_types = {"expressions": True} +# https://www.postgresql.org/docs/9.1/sql-selectinto.html +# https://docs.aws.amazon.com/redshift/latest/dg/r_SELECT_INTO.html#r_SELECT_INTO-examples +class Into(Expression): + arg_types = {"this": True, "temporary": False, "unlogged": False} + + class From(Expression): arg_types = {"expressions": True} @@ -1065,67 +1058,67 @@ class Property(Expression): class TableFormatProperty(Property): - pass + arg_types = {"this": True} class PartitionedByProperty(Property): - pass + arg_types = {"this": True} class FileFormatProperty(Property): - pass + arg_types = {"this": True} class DistKeyProperty(Property): - pass + arg_types = {"this": True} class SortKeyProperty(Property): - pass + arg_types = {"this": True, "compound": False} class DistStyleProperty(Property): - pass + arg_types = {"this": True} + + +class LikeProperty(Property): + arg_types = {"this": True, "expressions": False} class LocationProperty(Property): - pass + arg_types = {"this": True} class EngineProperty(Property): - pass + arg_types = {"this": True} class AutoIncrementProperty(Property): - pass + arg_types = {"this": True} class CharacterSetProperty(Property): - arg_types = {"this": True, "value": True, "default": True} + arg_types = {"this": True, "default": True} class CollateProperty(Property): - pass + arg_types = {"this": True} class SchemaCommentProperty(Property): - pass - - -class AnonymousProperty(Property): - pass + arg_types = {"this": True} class ReturnsProperty(Property): - arg_types = {"this": True, "value": True, "is_table": False} + arg_types = {"this": True, "is_table": False} class LanguageProperty(Property): - pass + arg_types = {"this": True} class ExecuteAsProperty(Property): - pass + arg_types = {"this": True} class VolatilityProperty(Property): @@ -1135,27 +1128,36 @@ class VolatilityProperty(Property): class Properties(Expression): arg_types = {"expressions": True} - PROPERTY_KEY_MAPPING = { + NAME_TO_PROPERTY = { "AUTO_INCREMENT": AutoIncrementProperty, - "CHARACTER_SET": CharacterSetProperty, + "CHARACTER SET": CharacterSetProperty, "COLLATE": CollateProperty, "COMMENT": SchemaCommentProperty, + "DISTKEY": DistKeyProperty, + "DISTSTYLE": DistStyleProperty, "ENGINE": EngineProperty, + "EXECUTE AS": ExecuteAsProperty, "FORMAT": FileFormatProperty, + "LANGUAGE": LanguageProperty, "LOCATION": LocationProperty, "PARTITIONED_BY": PartitionedByProperty, - "TABLE_FORMAT": TableFormatProperty, - "DISTKEY": DistKeyProperty, - "DISTSTYLE": DistStyleProperty, + "RETURNS": ReturnsProperty, "SORTKEY": SortKeyProperty, + "TABLE_FORMAT": TableFormatProperty, } + PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()} + @classmethod def from_dict(cls, properties_dict) -> Properties: expressions = [] for key, value in properties_dict.items(): - property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty) - expressions.append(property_cls(this=Literal.string(key), value=convert(value))) + property_cls = cls.NAME_TO_PROPERTY.get(key.upper()) + if property_cls: + expressions.append(property_cls(this=convert(value))) + else: + expressions.append(Property(this=Literal.string(key), value=convert(value))) + return cls(expressions=expressions) @@ -1383,6 +1385,7 @@ class Select(Subqueryable): "expressions": False, "hint": False, "distinct": False, + "into": False, "from": False, **QUERY_MODIFIERS, } @@ -2015,6 +2018,7 @@ class DataType(Expression): DECIMAL = auto() BOOLEAN = auto() JSON = auto() + JSONB = auto() INTERVAL = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() @@ -2029,6 +2033,7 @@ class DataType(Expression): STRUCT = auto() NULLABLE = auto() HLLSKETCH = auto() + HSTORE = auto() SUPER = auto() SERIAL = auto() SMALLSERIAL = auto() @@ -2109,7 +2114,7 @@ class Transaction(Command): class Commit(Command): - arg_types = {} # type: ignore + arg_types = {"chain": False} class Rollback(Command): @@ -2442,7 +2447,7 @@ class ArrayFilter(Func): class ArraySize(Func): - pass + arg_types = {"this": True, "expression": False} class ArraySort(Func): @@ -2726,6 +2731,16 @@ class VarMap(Func): is_var_len_args = True +class Matches(Func): + """Oracle/Snowflake decode. + https://docs.oracle.com/cd/B19306_01/server.102/b14200/functions040.htm + Pattern matching MATCHES(value, search1, result1, ...searchN, resultN, else) + """ + + arg_types = {"this": True, "expressions": True} + is_var_len_args = True + + class Max(AggFunc): pass @@ -2785,6 +2800,10 @@ class Round(Func): arg_types = {"this": True, "decimals": False} +class RowNumber(Func): + arg_types: t.Dict[str, t.Any] = {} + + class SafeDivide(Func): arg_types = {"this": True, "expression": True} diff --git a/sqlglot/generator.py b/sqlglot/generator.py index ffb34eb..47774fc 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1,19 +1,16 @@ from __future__ import annotations import logging -import re import typing as t from sqlglot import exp -from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors +from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages from sqlglot.helper import apply_index_offset, csv from sqlglot.time import format_time from sqlglot.tokens import TokenType logger = logging.getLogger("sqlglot") -NEWLINE_RE = re.compile("\r\n?|\n") - class Generator: """ @@ -58,11 +55,11 @@ class Generator: """ TRANSFORMS = { - exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}", exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})", exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})", exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})", exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})", + exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}", exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e), exp.ReturnsProperty: lambda self, e: self.naked_property(e), @@ -97,16 +94,17 @@ class Generator: exp.DistStyleProperty, exp.DistKeyProperty, exp.SortKeyProperty, + exp.LikeProperty, } WITH_PROPERTIES = { - exp.AnonymousProperty, + exp.Property, exp.FileFormatProperty, exp.PartitionedByProperty, exp.TableFormatProperty, } - WITH_SEPARATED_COMMENTS = (exp.Select,) + WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary) __slots__ = ( "time_mapping", @@ -211,7 +209,7 @@ class Generator: for msg in self.unsupported_messages: logger.warning(msg) elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages: - raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported)) + raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported)) return sql @@ -226,25 +224,24 @@ class Generator: def seg(self, sql, sep=" "): return f"{self.sep(sep)}{sql}" - def maybe_comment(self, sql, expression, single_line=False): - comment = expression.comment if self._comments else None - - if not comment: - return sql - + def pad_comment(self, comment): comment = " " + comment if comment[0].strip() else comment comment = comment + " " if comment[-1].strip() else comment + return comment - if isinstance(expression, self.WITH_SEPARATED_COMMENTS): - return f"/*{comment}*/{self.sep()}{sql}" + def maybe_comment(self, sql, expression): + comments = expression.comments if self._comments else None - if not self.pretty: - return f"{sql} /*{comment}*/" + if not comments: + return sql + + sep = "\n" if self.pretty else " " + comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments) - if not NEWLINE_RE.search(comment): - return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/" + if isinstance(expression, self.WITH_SEPARATED_COMMENTS): + return f"{comments}{self.sep()}{sql}" - return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/" + return f"{sql} {comments}" def wrap(self, expression): this_sql = self.indent( @@ -387,8 +384,11 @@ class Generator: def notnullcolumnconstraint_sql(self, _): return "NOT NULL" - def primarykeycolumnconstraint_sql(self, _): - return "PRIMARY KEY" + def primarykeycolumnconstraint_sql(self, expression): + desc = expression.args.get("desc") + if desc is not None: + return f"PRIMARY KEY{' DESC' if desc else ' ASC'}" + return f"PRIMARY KEY" def uniquecolumnconstraint_sql(self, _): return "UNIQUE" @@ -546,36 +546,33 @@ class Generator: def root_properties(self, properties): if properties.expressions: - return self.sep() + self.expressions( - properties, - indent=False, - sep=" ", - ) + return self.sep() + self.expressions(properties, indent=False, sep=" ") return "" def properties(self, properties, prefix="", sep=", "): if properties.expressions: - expressions = self.expressions( - properties, - sep=sep, - indent=False, - ) + expressions = self.expressions(properties, sep=sep, indent=False) return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}" return "" def with_properties(self, properties): - return self.properties( - properties, - prefix="WITH", - ) + return self.properties(properties, prefix="WITH") def property_sql(self, expression): - if isinstance(expression.this, exp.Literal): - key = expression.this.this - else: - key = expression.name - value = self.sql(expression, "value") - return f"{key}={value}" + property_cls = expression.__class__ + if property_cls == exp.Property: + return f"{expression.name}={self.sql(expression, 'value')}" + + property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls) + if not property_name: + self.unsupported(f"Unsupported property {property_name}") + + return f"{property_name}={self.sql(expression, 'this')}" + + def likeproperty_sql(self, expression): + options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions) + options = f" {options}" if options else "" + return f"LIKE {self.sql(expression, 'this')}{options}" def insert_sql(self, expression): overwrite = expression.args.get("overwrite") @@ -700,6 +697,11 @@ class Generator: def var_sql(self, expression): return self.sql(expression, "this") + def into_sql(self, expression): + temporary = " TEMPORARY" if expression.args.get("temporary") else "" + unlogged = " UNLOGGED" if expression.args.get("unlogged") else "" + return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}" + def from_sql(self, expression): expressions = self.expressions(expression, flat=True) return f"{self.seg('FROM')} {expressions}" @@ -883,6 +885,7 @@ class Generator: sql = self.query_modifiers( expression, f"SELECT{hint}{distinct}{expressions}", + self.sql(expression, "into", comment=False), self.sql(expression, "from", comment=False), ) return self.prepend_ctes(expression, sql) @@ -1061,6 +1064,11 @@ class Generator: else: return f"TRIM({target})" + def concat_sql(self, expression): + if len(expression.expressions) == 1: + return self.sql(expression.expressions[0]) + return self.function_fallback_sql(expression) + def check_sql(self, expression): this = self.sql(expression, key="this") return f"CHECK ({this})" @@ -1125,7 +1133,10 @@ class Generator: return self.prepend_ctes(expression, sql) def neg_sql(self, expression): - return f"-{self.sql(expression, 'this')}" + # This makes sure we don't convert "- - 5" to "--5", which is a comment + this_sql = self.sql(expression, "this") + sep = " " if this_sql[0] == "-" else "" + return f"-{sep}{this_sql}" def not_sql(self, expression): return f"NOT {self.sql(expression, 'this')}" @@ -1191,8 +1202,12 @@ class Generator: def transaction_sql(self, *_): return "BEGIN" - def commit_sql(self, *_): - return "COMMIT" + def commit_sql(self, expression): + chain = expression.args.get("chain") + if chain is not None: + chain = " AND CHAIN" if chain else " AND NO CHAIN" + + return f"COMMIT{chain or ''}" def rollback_sql(self, expression): savepoint = expression.args.get("savepoint") @@ -1334,15 +1349,15 @@ class Generator: result_sqls = [] for i, e in enumerate(expressions): sql = self.sql(e, comment=False) - comment = self.maybe_comment("", e, single_line=True) + comments = self.maybe_comment("", e) if self.pretty: if self._leading_comma: - result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}") + result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}") else: - result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}") + result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}") else: - result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}") + result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}") result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) return self.indent(result_sqls, skip_first=False) if indent else result_sqls @@ -1354,7 +1369,10 @@ class Generator: return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}" def naked_property(self, expression): - return f"{expression.name} {self.sql(expression, 'value')}" + property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__) + if not property_name: + self.unsupported(f"Unsupported property {expression.__class__.__name__}") + return f"{property_name} {self.sql(expression, 'this')}" def set_operation(self, expression, op): this = self.sql(expression, "this") diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 8704e90..39e252c 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -68,6 +68,9 @@ def eliminate_subqueries(expression): for cte_scope in root.cte_scopes: # Append all the new CTEs from this existing CTE for scope in cte_scope.traverse(): + if scope is cte_scope: + # Don't try to eliminate this CTE itself + continue new_cte = _eliminate(scope, existing_ctes, taken) if new_cte: new_ctes.append(new_cte) @@ -97,6 +100,9 @@ def _eliminate(scope, existing_ctes, taken): if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF): return _eliminate_derived_table(scope, existing_ctes, taken) + if scope.is_cte: + return _eliminate_cte(scope, existing_ctes, taken) + def _eliminate_union(scope, existing_ctes, taken): duplicate_cte_alias = existing_ctes.get(scope.expression) @@ -127,26 +133,61 @@ def _eliminate_union(scope, existing_ctes, taken): def _eliminate_derived_table(scope, existing_ctes, taken): + parent = scope.expression.parent + name, cte = _new_cte(scope, existing_ctes, taken) + + table = exp.alias_(exp.table_(name), alias=parent.alias or name) + parent.replace(table) + + return cte + + +def _eliminate_cte(scope, existing_ctes, taken): + parent = scope.expression.parent + name, cte = _new_cte(scope, existing_ctes, taken) + + with_ = parent.parent + parent.pop() + if not with_.expressions: + with_.pop() + + # Rename references to this CTE + for child_scope in scope.parent.traverse(): + for table, source in child_scope.selected_sources.values(): + if source is scope: + new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name) + table.replace(new_table) + + return cte + + +def _new_cte(scope, existing_ctes, taken): + """ + Returns: + tuple of (name, cte) + where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance. + If this CTE duplicates an existing CTE, `cte` will be None. + """ duplicate_cte_alias = existing_ctes.get(scope.expression) parent = scope.expression.parent - name = alias = parent.alias + name = parent.alias - if not alias: - name = alias = find_new_name(taken=taken, base="cte") + if not name: + name = find_new_name(taken=taken, base="cte") if duplicate_cte_alias: name = duplicate_cte_alias - elif taken.get(alias): - name = find_new_name(taken=taken, base=alias) + elif taken.get(name): + name = find_new_name(taken=taken, base=name) taken[name] = scope - table = exp.alias_(exp.table_(name), alias=alias) - parent.replace(table) - if not duplicate_cte_alias: existing_ctes[scope.expression] = name - return exp.CTE( + cte = exp.CTE( this=scope.expression, alias=exp.TableAlias(this=exp.to_identifier(name)), ) + else: + cte = None + return name, cte diff --git a/sqlglot/optimizer/lower_identities.py b/sqlglot/optimizer/lower_identities.py new file mode 100644 index 0000000..1cc76cf --- /dev/null +++ b/sqlglot/optimizer/lower_identities.py @@ -0,0 +1,92 @@ +from sqlglot import exp +from sqlglot.helper import ensure_collection + + +def lower_identities(expression): + """ + Convert all unquoted identifiers to lower case. + + Assuming the schema is all lower case, this essentially makes identifiers case-insensitive. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar') + >>> lower_identities(expression).sql() + 'SELECT bar.a AS A FROM "Foo".bar' + + Args: + expression (sqlglot.Expression): expression to quote + Returns: + sqlglot.Expression: quoted expression + """ + # We need to leave the output aliases unchanged, so the selects need special handling + _lower_selects(expression) + + # These clauses can reference output aliases and also need special handling + _lower_order(expression) + _lower_having(expression) + + # We've already handled these args, so don't traverse into them + traversed = {"expressions", "order", "having"} + + if isinstance(expression, exp.Subquery): + # Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1 + lower_identities(expression.this) + traversed |= {"this"} + + if isinstance(expression, exp.Union): + # Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X + lower_identities(expression.left) + lower_identities(expression.right) + traversed |= {"this", "expression"} + + for k, v in expression.args.items(): + if k in traversed: + continue + + for child in ensure_collection(v): + if isinstance(child, exp.Expression): + child.transform(_lower, copy=False) + + return expression + + +def _lower_selects(expression): + for e in expression.expressions: + # Leave output aliases as-is + e.unalias().transform(_lower, copy=False) + + +def _lower_order(expression): + order = expression.args.get("order") + + if not order: + return + + output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)} + + for ordered in order.expressions: + # Don't lower references to output aliases + if not ( + isinstance(ordered.this, exp.Column) + and not ordered.this.table + and ordered.this.name in output_aliases + ): + ordered.transform(_lower, copy=False) + + +def _lower_having(expression): + having = expression.args.get("having") + + if not having: + return + + # Don't lower references to output aliases + for agg in having.find_all(exp.AggFunc): + agg.transform(_lower, copy=False) + + +def _lower(node): + if isinstance(node, exp.Identifier) and not node.quoted: + node.set("this", node.this.lower()) + return node diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index d0e38cd..6819717 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -6,6 +6,7 @@ from sqlglot.optimizer.eliminate_joins import eliminate_joins from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects from sqlglot.optimizer.isolate_table_selects import isolate_table_selects +from sqlglot.optimizer.lower_identities import lower_identities from sqlglot.optimizer.merge_subqueries import merge_subqueries from sqlglot.optimizer.normalize import normalize from sqlglot.optimizer.optimize_joins import optimize_joins @@ -17,6 +18,7 @@ from sqlglot.optimizer.quote_identities import quote_identities from sqlglot.optimizer.unnest_subqueries import unnest_subqueries RULES = ( + lower_identities, qualify_tables, isolate_table_selects, qualify_columns, diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index dbd680b..2046917 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -1,16 +1,15 @@ import itertools from sqlglot import exp -from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.scope import ScopeType, traverse_scope def unnest_subqueries(expression): """ Rewrite sqlglot AST to convert some predicates with subqueries into joins. - Convert the subquery into a group by so it is not a many to many left join. - Unnesting can only occur if the subquery does not have LIMIT or OFFSET. - Unnesting non correlated subqueries only happens on IN statements or = ANY statements. + Convert scalar subqueries into cross joins. + Convert correlated or vectorized subqueries into a group by so it is not a many to many left join. Example: >>> import sqlglot @@ -29,21 +28,43 @@ def unnest_subqueries(expression): for scope in traverse_scope(expression): select = scope.expression parent = select.parent_select + if not parent: + continue if scope.external_columns: decorrelate(select, parent, scope.external_columns, sequence) - else: + elif scope.scope_type == ScopeType.SUBQUERY: unnest(select, parent, sequence) return expression def unnest(select, parent_select, sequence): - predicate = select.find_ancestor(exp.In, exp.Any) + if len(select.selects) > 1: + return + + predicate = select.find_ancestor(exp.Condition) + alias = _alias(sequence) if not predicate or parent_select is not predicate.parent_select: return - if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset): + # this subquery returns a scalar and can just be converted to a cross join + if not isinstance(predicate, (exp.In, exp.Any)): + having = predicate.find_ancestor(exp.Having) + column = exp.column(select.selects[0].alias_or_name, alias) + if having and having.parent_select is parent_select: + column = exp.Max(this=column) + _replace(select.parent, column) + + parent_select.join( + select, + join_type="CROSS", + join_alias=alias, + copy=False, + ) + return + + if select.find(exp.Limit, exp.Offset): return if isinstance(predicate, exp.Any): @@ -54,7 +75,6 @@ def unnest(select, parent_select, sequence): column = _other_operand(predicate) value = select.selects[0] - alias = _alias(sequence) on = exp.condition(f'{column} = "{alias}"."{value.alias}"') _replace(predicate, f"NOT {on.right} IS NULL") diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 5b93510..bdf0d2d 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -4,7 +4,7 @@ import logging import typing as t from sqlglot import exp -from sqlglot.errors import ErrorLevel, ParseError, concat_errors +from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors from sqlglot.helper import apply_index_offset, ensure_collection, seq_get from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import in_trie, new_trie @@ -104,6 +104,7 @@ class Parser(metaclass=_Parser): TokenType.BINARY, TokenType.VARBINARY, TokenType.JSON, + TokenType.JSONB, TokenType.INTERVAL, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, @@ -115,6 +116,7 @@ class Parser(metaclass=_Parser): TokenType.GEOGRAPHY, TokenType.GEOMETRY, TokenType.HLLSKETCH, + TokenType.HSTORE, TokenType.SUPER, TokenType.SERIAL, TokenType.SMALLSERIAL, @@ -153,6 +155,7 @@ class Parser(metaclass=_Parser): TokenType.COLLATE, TokenType.COMMAND, TokenType.COMMIT, + TokenType.COMPOUND, TokenType.CONSTRAINT, TokenType.CURRENT_TIME, TokenType.DEFAULT, @@ -194,6 +197,7 @@ class Parser(metaclass=_Parser): TokenType.RANGE, TokenType.REFERENCES, TokenType.RETURNS, + TokenType.ROW, TokenType.ROWS, TokenType.SCHEMA, TokenType.SCHEMA_COMMENT, @@ -213,6 +217,7 @@ class Parser(metaclass=_Parser): TokenType.TRUE, TokenType.UNBOUNDED, TokenType.UNIQUE, + TokenType.UNLOGGED, TokenType.UNPIVOT, TokenType.PROPERTIES, TokenType.PROCEDURE, @@ -400,9 +405,17 @@ class Parser(metaclass=_Parser): TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()), TokenType.BEGIN: lambda self: self._parse_transaction(), TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), + TokenType.END: lambda self: self._parse_commit_or_rollback(), TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), } + UNARY_PARSERS = { + TokenType.PLUS: lambda self: self._parse_unary(), # Unary + is handled as a no-op + TokenType.NOT: lambda self: self.expression(exp.Not, this=self._parse_equality()), + TokenType.TILDA: lambda self: self.expression(exp.BitwiseNot, this=self._parse_unary()), + TokenType.DASH: lambda self: self.expression(exp.Neg, this=self._parse_unary()), + } + PRIMARY_PARSERS = { TokenType.STRING: lambda self, token: self.expression( exp.Literal, this=token.text, is_string=True @@ -446,19 +459,20 @@ class Parser(metaclass=_Parser): } PROPERTY_PARSERS = { - TokenType.AUTO_INCREMENT: lambda self: self._parse_auto_increment(), - TokenType.CHARACTER_SET: lambda self: self._parse_character_set(), - TokenType.LOCATION: lambda self: self.expression( - exp.LocationProperty, - this=exp.Literal.string("LOCATION"), - value=self._parse_string(), + TokenType.AUTO_INCREMENT: lambda self: self._parse_property_assignment( + exp.AutoIncrementProperty ), + TokenType.CHARACTER_SET: lambda self: self._parse_character_set(), + TokenType.LOCATION: lambda self: self._parse_property_assignment(exp.LocationProperty), TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(), - TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(), - TokenType.STORED: lambda self: self._parse_stored(), + TokenType.SCHEMA_COMMENT: lambda self: self._parse_property_assignment( + exp.SchemaCommentProperty + ), + TokenType.STORED: lambda self: self._parse_property_assignment(exp.FileFormatProperty), TokenType.DISTKEY: lambda self: self._parse_distkey(), - TokenType.DISTSTYLE: lambda self: self._parse_diststyle(), + TokenType.DISTSTYLE: lambda self: self._parse_property_assignment(exp.DistStyleProperty), TokenType.SORTKEY: lambda self: self._parse_sortkey(), + TokenType.LIKE: lambda self: self._parse_create_like(), TokenType.RETURNS: lambda self: self._parse_returns(), TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty), TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), @@ -468,7 +482,7 @@ class Parser(metaclass=_Parser): ), TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty), TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty), - TokenType.EXECUTE: lambda self: self._parse_execute_as(), + TokenType.EXECUTE: lambda self: self._parse_property_assignment(exp.ExecuteAsProperty), TokenType.DETERMINISTIC: lambda self: self.expression( exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE") ), @@ -489,6 +503,7 @@ class Parser(metaclass=_Parser): ), TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(), TokenType.UNIQUE: lambda self: self._parse_unique(), + TokenType.LIKE: lambda self: self._parse_create_like(), } NO_PAREN_FUNCTION_PARSERS = { @@ -505,6 +520,7 @@ class Parser(metaclass=_Parser): "TRIM": lambda self: self._parse_trim(), "CAST": lambda self: self._parse_cast(self.STRICT_CAST), "TRY_CAST": lambda self: self._parse_cast(False), + "STRING_AGG": lambda self: self._parse_string_agg(), } QUERY_MODIFIER_PARSERS = { @@ -556,7 +572,7 @@ class Parser(metaclass=_Parser): "_curr", "_next", "_prev", - "_prev_comment", + "_prev_comments", "_show_trie", "_set_trie", ) @@ -589,7 +605,7 @@ class Parser(metaclass=_Parser): self._curr = None self._next = None self._prev = None - self._prev_comment = None + self._prev_comments = None def parse(self, raw_tokens, sql=None): """ @@ -608,6 +624,7 @@ class Parser(metaclass=_Parser): ) def parse_into(self, expression_types, raw_tokens, sql=None): + errors = [] for expression_type in ensure_collection(expression_types): parser = self.EXPRESSION_PARSERS.get(expression_type) if not parser: @@ -615,8 +632,12 @@ class Parser(metaclass=_Parser): try: return self._parse(parser, raw_tokens, sql) except ParseError as e: - error = e - raise ParseError(f"Failed to parse into {expression_types}") from error + e.errors[0]["into_expression"] = expression_type + errors.append(e) + raise ParseError( + f"Failed to parse into {expression_types}", + errors=merge_errors(errors), + ) from errors[-1] def _parse(self, parse_method, raw_tokens, sql=None): self.reset() @@ -650,7 +671,10 @@ class Parser(metaclass=_Parser): for error in self.errors: logger.error(str(error)) elif self.error_level == ErrorLevel.RAISE and self.errors: - raise ParseError(concat_errors(self.errors, self.max_errors)) + raise ParseError( + concat_messages(self.errors, self.max_errors), + errors=merge_errors(self.errors), + ) def raise_error(self, message, token=None): token = token or self._curr or self._prev or Token.string("") @@ -659,19 +683,27 @@ class Parser(metaclass=_Parser): start_context = self.sql[max(start - self.error_message_context, 0) : start] highlight = self.sql[start:end] end_context = self.sql[end : end + self.error_message_context] - error = ParseError( + error = ParseError.new( f"{message}. Line {token.line}, Col: {token.col}.\n" - f" {start_context}\033[4m{highlight}\033[0m{end_context}" + f" {start_context}\033[4m{highlight}\033[0m{end_context}", + description=message, + line=token.line, + col=token.col, + start_context=start_context, + highlight=highlight, + end_context=end_context, ) if self.error_level == ErrorLevel.IMMEDIATE: raise error self.errors.append(error) - def expression(self, exp_class, **kwargs): + def expression(self, exp_class, comments=None, **kwargs): instance = exp_class(**kwargs) - if self._prev_comment: - instance.comment = self._prev_comment - self._prev_comment = None + if self._prev_comments: + instance.comments = self._prev_comments + self._prev_comments = None + if comments: + instance.comments = comments self.validate_expression(instance) return instance @@ -714,10 +746,10 @@ class Parser(metaclass=_Parser): self._next = seq_get(self._tokens, self._index + 1) if self._index > 0: self._prev = self._tokens[self._index - 1] - self._prev_comment = self._prev.comment + self._prev_comments = self._prev.comments else: self._prev = None - self._prev_comment = None + self._prev_comments = None def _retreat(self, index): self._advance(index - self._index) @@ -768,7 +800,7 @@ class Parser(metaclass=_Parser): ) def _parse_create(self): - replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE) + replace = self._match_pair(TokenType.OR, TokenType.REPLACE) temporary = self._match(TokenType.TEMPORARY) transient = self._match(TokenType.TRANSIENT) unique = self._match(TokenType.UNIQUE) @@ -822,97 +854,57 @@ class Parser(metaclass=_Parser): def _parse_property(self): if self._match_set(self.PROPERTY_PARSERS): return self.PROPERTY_PARSERS[self._prev.token_type](self) + if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET): return self._parse_character_set(True) + if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY): + return self._parse_sortkey(compound=True) + if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False): - key = self._parse_var().this + key = self._parse_var() self._match(TokenType.EQ) - - return self.expression( - exp.AnonymousProperty, - this=exp.Literal.string(key), - value=self._parse_column(), - ) + return self.expression(exp.Property, this=key, value=self._parse_column()) return None def _parse_property_assignment(self, exp_class): - prop = self._prev.text self._match(TokenType.EQ) - return self.expression(exp_class, this=prop, value=self._parse_var_or_string()) + self._match(TokenType.ALIAS) + return self.expression(exp_class, this=self._parse_var_or_string() or self._parse_number()) def _parse_partitioned_by(self): self._match(TokenType.EQ) return self.expression( exp.PartitionedByProperty, - this=exp.Literal.string("PARTITIONED_BY"), - value=self._parse_schema() or self._parse_bracket(self._parse_field()), - ) - - def _parse_stored(self): - self._match(TokenType.ALIAS) - self._match(TokenType.EQ) - return self.expression( - exp.FileFormatProperty, - this=exp.Literal.string("FORMAT"), - value=exp.Literal.string(self._parse_var_or_string().name), + this=self._parse_schema() or self._parse_bracket(self._parse_field()), ) def _parse_distkey(self): - self._match_l_paren() - this = exp.Literal.string("DISTKEY") - value = exp.Literal.string(self._parse_var().name) - self._match_r_paren() - return self.expression( - exp.DistKeyProperty, - this=this, - value=value, - ) + return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_var)) - def _parse_sortkey(self): - self._match_l_paren() - this = exp.Literal.string("SORTKEY") - value = exp.Literal.string(self._parse_var().name) - self._match_r_paren() - return self.expression( - exp.SortKeyProperty, - this=this, - value=value, - ) - - def _parse_diststyle(self): - this = exp.Literal.string("DISTSTYLE") - value = exp.Literal.string(self._parse_var().name) - return self.expression( - exp.DistStyleProperty, - this=this, - value=value, - ) - - def _parse_auto_increment(self): - self._match(TokenType.EQ) - return self.expression( - exp.AutoIncrementProperty, - this=exp.Literal.string("AUTO_INCREMENT"), - value=self._parse_number(), - ) + def _parse_create_like(self): + table = self._parse_table(schema=True) + options = [] + while self._match_texts(("INCLUDING", "EXCLUDING")): + options.append( + self.expression( + exp.Property, + this=self._prev.text.upper(), + value=exp.Var(this=self._parse_id_var().this.upper()), + ) + ) + return self.expression(exp.LikeProperty, this=table, expressions=options) - def _parse_schema_comment(self): - self._match(TokenType.EQ) + def _parse_sortkey(self, compound=False): return self.expression( - exp.SchemaCommentProperty, - this=exp.Literal.string("COMMENT"), - value=self._parse_string(), + exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_var), compound=compound ) def _parse_character_set(self, default=False): self._match(TokenType.EQ) return self.expression( - exp.CharacterSetProperty, - this=exp.Literal.string("CHARACTER_SET"), - value=self._parse_var_or_string(), - default=default, + exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default ) def _parse_returns(self): @@ -931,20 +923,7 @@ class Parser(metaclass=_Parser): else: value = self._parse_types() - return self.expression( - exp.ReturnsProperty, - this=exp.Literal.string("RETURNS"), - value=value, - is_table=is_table, - ) - - def _parse_execute_as(self): - self._match(TokenType.ALIAS) - return self.expression( - exp.ExecuteAsProperty, - this=exp.Literal.string("EXECUTE AS"), - value=self._parse_var(), - ) + return self.expression(exp.ReturnsProperty, this=value, is_table=is_table) def _parse_properties(self): properties = [] @@ -956,7 +935,7 @@ class Parser(metaclass=_Parser): properties.extend( self._parse_wrapped_csv( lambda: self.expression( - exp.AnonymousProperty, + exp.Property, this=self._parse_string(), value=self._match(TokenType.EQ) and self._parse_string(), ) @@ -1076,7 +1055,12 @@ class Parser(metaclass=_Parser): options = [] if self._match(TokenType.OPTIONS): - options = self._parse_wrapped_csv(self._parse_string, sep=TokenType.EQ) + self._match_l_paren() + k = self._parse_string() + self._match(TokenType.EQ) + v = self._parse_string() + options = [k, v] + self._match_r_paren() self._match(TokenType.ALIAS) return self.expression( @@ -1116,7 +1100,7 @@ class Parser(metaclass=_Parser): self.raise_error(f"{this.key} does not support CTE") this = cte elif self._match(TokenType.SELECT): - comment = self._prev_comment + comments = self._prev_comments hint = self._parse_hint() all_ = self._match(TokenType.ALL) @@ -1141,10 +1125,16 @@ class Parser(metaclass=_Parser): expressions=expressions, limit=limit, ) - this.comment = comment + this.comments = comments + + into = self._parse_into() + if into: + this.set("into", into) + from_ = self._parse_from() if from_: this.set("from", from_) + self._parse_query_modifiers(this) elif (table or nested) and self._match(TokenType.L_PAREN): this = self._parse_table() if table else self._parse_select(nested=True) @@ -1248,11 +1238,24 @@ class Parser(metaclass=_Parser): return self.expression(exp.Hint, expressions=hints) return None + def _parse_into(self): + if not self._match(TokenType.INTO): + return None + + temp = self._match(TokenType.TEMPORARY) + unlogged = self._match(TokenType.UNLOGGED) + self._match(TokenType.TABLE) + + return self.expression( + exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged + ) + def _parse_from(self): if not self._match(TokenType.FROM): return None - - return self.expression(exp.From, expressions=self._parse_csv(self._parse_table)) + return self.expression( + exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table) + ) def _parse_lateral(self): outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY) @@ -1515,7 +1518,9 @@ class Parser(metaclass=_Parser): def _parse_where(self, skip_where_token=False): if not skip_where_token and not self._match(TokenType.WHERE): return None - return self.expression(exp.Where, this=self._parse_conjunction()) + return self.expression( + exp.Where, comments=self._prev_comments, this=self._parse_conjunction() + ) def _parse_group(self, skip_group_by_token=False): if not skip_group_by_token and not self._match(TokenType.GROUP_BY): @@ -1737,12 +1742,8 @@ class Parser(metaclass=_Parser): return self._parse_tokens(self._parse_unary, self.FACTOR) def _parse_unary(self): - if self._match(TokenType.NOT): - return self.expression(exp.Not, this=self._parse_equality()) - if self._match(TokenType.TILDA): - return self.expression(exp.BitwiseNot, this=self._parse_unary()) - if self._match(TokenType.DASH): - return self.expression(exp.Neg, this=self._parse_unary()) + if self._match_set(self.UNARY_PARSERS): + return self.UNARY_PARSERS[self._prev.token_type](self) return self._parse_at_time_zone(self._parse_type()) def _parse_type(self): @@ -1775,17 +1776,6 @@ class Parser(metaclass=_Parser): expressions = None maybe_func = False - if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): - return exp.DataType( - this=exp.DataType.Type.ARRAY, - expressions=[exp.DataType.build(type_token.value)], - nested=True, - ) - - if self._match(TokenType.L_BRACKET): - self._retreat(index) - return None - if self._match(TokenType.L_PAREN): if is_struct: expressions = self._parse_csv(self._parse_struct_kwargs) @@ -1801,6 +1791,17 @@ class Parser(metaclass=_Parser): self._match_r_paren() maybe_func = True + if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): + return exp.DataType( + this=exp.DataType.Type.ARRAY, + expressions=[exp.DataType.build(type_token.value, expressions=expressions)], + nested=True, + ) + + if self._match(TokenType.L_BRACKET): + self._retreat(index) + return None + if nested and self._match(TokenType.LT): if is_struct: expressions = self._parse_csv(self._parse_struct_kwargs) @@ -1904,7 +1905,7 @@ class Parser(metaclass=_Parser): return exp.Literal.number(f"0.{self._prev.text}") if self._match(TokenType.L_PAREN): - comment = self._prev_comment + comments = self._prev_comments query = self._parse_select() if query: @@ -1924,8 +1925,8 @@ class Parser(metaclass=_Parser): this = self.expression(exp.Tuple, expressions=expressions) else: this = self.expression(exp.Paren, this=this) - if comment: - this.comment = comment + if comments: + this.comments = comments return this return None @@ -2098,7 +2099,10 @@ class Parser(metaclass=_Parser): elif self._match(TokenType.SCHEMA_COMMENT): kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string()) elif self._match(TokenType.PRIMARY_KEY): - kind = exp.PrimaryKeyColumnConstraint() + desc = None + if self._match(TokenType.ASC) or self._match(TokenType.DESC): + desc = self._prev.token_type == TokenType.DESC + kind = exp.PrimaryKeyColumnConstraint(desc=desc) elif self._match(TokenType.UNIQUE): kind = exp.UniqueColumnConstraint() elif self._match(TokenType.GENERATED): @@ -2189,7 +2193,7 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.R_BRACKET): self.raise_error("Expected ]") - this.comment = self._prev_comment + this.comments = self._prev_comments return self._parse_bracket(this) def _parse_case(self): @@ -2256,6 +2260,33 @@ class Parser(metaclass=_Parser): return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) + def _parse_string_agg(self): + if self._match(TokenType.DISTINCT): + args = self._parse_csv(self._parse_conjunction) + expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)]) + else: + args = self._parse_csv(self._parse_conjunction) + expression = seq_get(args, 0) + + index = self._index + if not self._match(TokenType.R_PAREN): + # postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]]) + order = self._parse_order(this=expression) + return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1)) + + # Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]). + # This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that + # the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them. + if not self._match(TokenType.WITHIN_GROUP): + self._retreat(index) + this = exp.GroupConcat.from_arg_list(args) + self.validate_expression(this, args) + return this + + self._match_l_paren() # The corresponding match_r_paren will be called in parse_function (caller) + order = self._parse_order(this=expression) + return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1)) + def _parse_convert(self, strict): this = self._parse_column() if self._match(TokenType.USING): @@ -2511,8 +2542,8 @@ class Parser(metaclass=_Parser): items = [parse_result] if parse_result is not None else [] while self._match(sep): - if parse_result and self._prev_comment is not None: - parse_result.comment = self._prev_comment + if parse_result and self._prev_comments: + parse_result.comments = self._prev_comments parse_result = parse_method() if parse_result is not None: @@ -2525,7 +2556,10 @@ class Parser(metaclass=_Parser): while self._match_set(expressions): this = self.expression( - expressions[self._prev.token_type], this=this, expression=parse_method() + expressions[self._prev.token_type], + this=this, + comments=self._prev_comments, + expression=parse_method(), ) return this @@ -2566,6 +2600,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Transaction, this=this, modes=modes) def _parse_commit_or_rollback(self): + chain = None savepoint = None is_rollback = self._prev.token_type == TokenType.ROLLBACK @@ -2575,9 +2610,13 @@ class Parser(metaclass=_Parser): self._match_text_seq("SAVEPOINT") savepoint = self._parse_id_var() + if self._match(TokenType.AND): + chain = not self._match_text_seq("NO") + self._match_text_seq("CHAIN") + if is_rollback: return self.expression(exp.Rollback, savepoint=savepoint) - return self.expression(exp.Commit) + return self.expression(exp.Commit, chain=chain) def _parse_show(self): parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) @@ -2651,14 +2690,14 @@ class Parser(metaclass=_Parser): def _match_l_paren(self, expression=None): if not self._match(TokenType.L_PAREN): self.raise_error("Expecting (") - if expression and self._prev_comment: - expression.comment = self._prev_comment + if expression and self._prev_comments: + expression.comments = self._prev_comments def _match_r_paren(self, expression=None): if not self._match(TokenType.R_PAREN): self.raise_error("Expecting )") - if expression and self._prev_comment: - expression.comment = self._prev_comment + if expression and self._prev_comments: + expression.comments = self._prev_comments def _match_texts(self, texts): if self._curr and self._curr.text.upper() in texts: diff --git a/sqlglot/planner.py b/sqlglot/planner.py index 51db2d4..4967231 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -130,18 +130,20 @@ class Step: aggregations = [] sequence = itertools.count() - for e in expression.expressions: - aggregation = e.find(exp.AggFunc) - - if aggregation: - projections.append(exp.column(e.alias_or_name, step.name, quoted=True)) - aggregations.append(e) - for operand in aggregation.unnest_operands(): + def extract_agg_operands(expression): + for agg in expression.find_all(exp.AggFunc): + for operand in agg.unnest_operands(): if isinstance(operand, exp.Column): continue if operand not in operands: operands[operand] = f"_a_{next(sequence)}" operand.replace(exp.column(operands[operand], quoted=True)) + + for e in expression.expressions: + if e.find(exp.AggFunc): + projections.append(exp.column(e.alias_or_name, step.name, quoted=True)) + aggregations.append(e) + extract_agg_operands(e) else: projections.append(e) @@ -156,6 +158,13 @@ class Step: aggregate = Aggregate() aggregate.source = step.name aggregate.name = step.name + + having = expression.args.get("having") + + if having: + extract_agg_operands(having) + aggregate.condition = having.this + aggregate.operands = tuple( alias(operand, alias_) for operand, alias_ in operands.items() ) @@ -172,11 +181,6 @@ class Step: aggregate.add_dependency(step) step = aggregate - having = expression.args.get("having") - - if having: - step.condition = having.this - order = expression.args.get("order") if order: @@ -188,6 +192,17 @@ class Step: step.projections = projections + if isinstance(expression, exp.Select) and expression.args.get("distinct"): + distinct = Aggregate() + distinct.source = step.name + distinct.name = step.name + distinct.group = { + e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name) + for e in projections or expression.expressions + } + distinct.add_dependency(step) + step = distinct + limit = expression.args.get("limit") if limit: @@ -231,6 +246,9 @@ class Step: if self.condition: lines.append(f"{nested}Condition: {self.condition.sql()}") + if self.limit is not math.inf: + lines.append(f"{nested}Limit: {self.limit}") + if self.dependencies: lines.append(f"{nested}Dependencies:") for dependency in self.dependencies: @@ -258,12 +276,7 @@ class Scan(Step): cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None ) -> Step: table = expression - alias_ = expression.alias - - if not alias_: - raise UnsupportedError( - "Tables/Subqueries must be aliased. Run it through the optimizer" - ) + alias_ = expression.alias_or_name if isinstance(expression, exp.Subquery): table = expression.this @@ -338,6 +351,9 @@ class Aggregate(Step): lines.append(f"{indent}Group:") for expression in self.group.values(): lines.append(f"{indent} - {expression.sql()}") + if self.condition: + lines.append(f"{indent}Having:") + lines.append(f"{indent} - {self.condition.sql()}") if self.operands: lines.append(f"{indent}Operands:") for expression in self.operands: diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index ec8cd91..8a7a38e 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -81,6 +81,7 @@ class TokenType(AutoName): BINARY = auto() VARBINARY = auto() JSON = auto() + JSONB = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() TIMESTAMPLTZ = auto() @@ -91,6 +92,7 @@ class TokenType(AutoName): NULLABLE = auto() GEOMETRY = auto() HLLSKETCH = auto() + HSTORE = auto() SUPER = auto() SERIAL = auto() SMALLSERIAL = auto() @@ -113,6 +115,7 @@ class TokenType(AutoName): APPLY = auto() ARRAY = auto() ASC = auto() + ASOF = auto() AT_TIME_ZONE = auto() AUTO_INCREMENT = auto() BEGIN = auto() @@ -130,6 +133,7 @@ class TokenType(AutoName): COMMAND = auto() COMMENT = auto() COMMIT = auto() + COMPOUND = auto() CONSTRAINT = auto() CREATE = auto() CROSS = auto() @@ -271,6 +275,7 @@ class TokenType(AutoName): UNBOUNDED = auto() UNCACHE = auto() UNION = auto() + UNLOGGED = auto() UNNEST = auto() UNPIVOT = auto() UPDATE = auto() @@ -291,7 +296,7 @@ class TokenType(AutoName): class Token: - __slots__ = ("token_type", "text", "line", "col", "comment") + __slots__ = ("token_type", "text", "line", "col", "comments") @classmethod def number(cls, number: int) -> Token: @@ -319,13 +324,13 @@ class Token: text: str, line: int = 1, col: int = 1, - comment: t.Optional[str] = None, + comments: t.List[str] = [], ) -> None: self.token_type = token_type self.text = text self.line = line self.col = max(col - len(text), 1) - self.comment = comment + self.comments = comments def __repr__(self) -> str: attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__) @@ -452,6 +457,7 @@ class Tokenizer(metaclass=_Tokenizer): "COLLATE": TokenType.COLLATE, "COMMENT": TokenType.SCHEMA_COMMENT, "COMMIT": TokenType.COMMIT, + "COMPOUND": TokenType.COMPOUND, "CONSTRAINT": TokenType.CONSTRAINT, "CREATE": TokenType.CREATE, "CROSS": TokenType.CROSS, @@ -582,8 +588,9 @@ class Tokenizer(metaclass=_Tokenizer): "TRAILING": TokenType.TRAILING, "UNBOUNDED": TokenType.UNBOUNDED, "UNION": TokenType.UNION, - "UNPIVOT": TokenType.UNPIVOT, + "UNLOGGED": TokenType.UNLOGGED, "UNNEST": TokenType.UNNEST, + "UNPIVOT": TokenType.UNPIVOT, "UPDATE": TokenType.UPDATE, "USE": TokenType.USE, "USING": TokenType.USING, @@ -686,12 +693,12 @@ class Tokenizer(metaclass=_Tokenizer): "_current", "_line", "_col", - "_comment", + "_comments", "_char", "_end", "_peek", "_prev_token_line", - "_prev_token_comment", + "_prev_token_comments", "_prev_token_type", "_replace_backslash", ) @@ -708,13 +715,13 @@ class Tokenizer(metaclass=_Tokenizer): self._current = 0 self._line = 1 self._col = 1 - self._comment = None + self._comments: t.List[str] = [] self._char = None self._end = None self._peek = None self._prev_token_line = -1 - self._prev_token_comment = None + self._prev_token_comments: t.List[str] = [] self._prev_token_type = None def tokenize(self, sql: str) -> t.List[Token]: @@ -767,7 +774,7 @@ class Tokenizer(metaclass=_Tokenizer): def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None: self._prev_token_line = self._line - self._prev_token_comment = self._comment + self._prev_token_comments = self._comments self._prev_token_type = token_type # type: ignore self.tokens.append( Token( @@ -775,10 +782,10 @@ class Tokenizer(metaclass=_Tokenizer): self._text if text is None else text, self._line, self._col, - self._comment, + self._comments, ) ) - self._comment = None + self._comments = [] if token_type in self.COMMANDS and ( len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON @@ -857,22 +864,18 @@ class Tokenizer(metaclass=_Tokenizer): while not self._end and self._chars(comment_end_size) != comment_end: self._advance() - self._comment = self._text[comment_start_size : -comment_end_size + 1] # type: ignore + self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore self._advance(comment_end_size - 1) else: while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore self._advance() - self._comment = self._text[comment_start_size:] # type: ignore - - # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both - # types of comment can be attached to a token, the trailing one is discarded in favour of the leading one. + self._comments.append(self._text[comment_start_size:]) # type: ignore + # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. + # Multiple consecutive comments are preserved by appending them to the current comments list. if comment_start_line == self._prev_token_line: - if self._prev_token_comment is None: - self.tokens[-1].comment = self._comment - self._prev_token_comment = self._comment - - self._comment = None + self.tokens[-1].comments.extend(self._comments) + self._comments = [] return True diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 412b881..99949a1 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -2,6 +2,8 @@ from __future__ import annotations import typing as t +from sqlglot.helper import find_new_name + if t.TYPE_CHECKING: from sqlglot.generator import Generator @@ -43,6 +45,43 @@ def unalias_group(expression: exp.Expression) -> exp.Expression: return expression +def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: + """ + Convert SELECT DISTINCT ON statements to a subquery with a window function. + + This is useful for dialects that don't support SELECT DISTINCT ON but support window functions. + + Args: + expression: the expression that will be transformed. + + Returns: + The transformed expression. + """ + if ( + isinstance(expression, exp.Select) + and expression.args.get("distinct") + and expression.args["distinct"].args.get("on") + and isinstance(expression.args["distinct"].args["on"], exp.Tuple) + ): + distinct_cols = [e.copy() for e in expression.args["distinct"].args["on"].expressions] + outer_selects = [e.copy() for e in expression.expressions] + nested = expression.copy() + nested.args["distinct"].pop() + row_number = find_new_name(expression.named_selects, "_row_number") + window = exp.Window( + this=exp.RowNumber(), + partition_by=distinct_cols, + ) + order = nested.args.get("order") + if order: + window.set("order", order.copy()) + order.pop() + window = exp.alias_(window, row_number) + nested.select(window, copy=False) + return exp.select(*outer_selects).from_(nested.subquery()).where(f'"{row_number}" = 1') + return expression + + def preprocess( transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], to_sql: t.Callable[[Generator, exp.Expression], str], @@ -81,3 +120,4 @@ def delegate(attr: str) -> t.Callable: UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))} +ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))} |