author     Daniel Baumann <daniel.baumann@progress-linux.org>  2022-12-02 09:16:32 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2022-12-02 09:16:32 +0000
commit     b3c7fe6a73484a4d2177c30f951cd11a4916ed56 (patch)
tree       7192898cb782bbb0b9b13bd8d6341fe4434f0f31 /sqlglot
parent     Releasing debian version 10.0.8-1. (diff)
Merging upstream version 10.1.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
 sqlglot/__init__.py                       |   2
 sqlglot/dialects/bigquery.py              |  20
 sqlglot/dialects/clickhouse.py            |  13
 sqlglot/dialects/dialect.py               |  21
 sqlglot/dialects/drill.py                 |   2
 sqlglot/dialects/hive.py                  |  14
 sqlglot/dialects/mysql.py                 |   1
 sqlglot/dialects/oracle.py                |  11
 sqlglot/dialects/postgres.py              |  48
 sqlglot/dialects/presto.py                |  14
 sqlglot/dialects/redshift.py              |  18
 sqlglot/dialects/snowflake.py             |  25
 sqlglot/dialects/spark.py                 |   2
 sqlglot/dialects/sqlite.py                |  18
 sqlglot/dialects/tsql.py                  |  41
 sqlglot/errors.py                         |  41
 sqlglot/executor/env.py                   |   1
 sqlglot/executor/python.py                |  46
 sqlglot/expressions.py                    | 111
 sqlglot/generator.py                      | 120
 sqlglot/optimizer/eliminate_subqueries.py |  59
 sqlglot/optimizer/lower_identities.py     |  92
 sqlglot/optimizer/optimizer.py            |   2
 sqlglot/optimizer/unnest_subqueries.py    |  36
 sqlglot/parser.py                         | 321
 sqlglot/planner.py                        |  52
 sqlglot/tokens.py                         |  45
 sqlglot/transforms.py                     |  40
 28 files changed, 827 insertions(+), 389 deletions(-)
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 50e2d9c..b027ac7 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -30,7 +30,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "10.0.8"
+__version__ = "10.1.3"
pretty = False
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 4550d65..5b44912 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -56,12 +56,12 @@ def _derived_table_values_to_unnest(self, expression):
def _returnsproperty_sql(self, expression):
- value = expression.args.get("value")
- if isinstance(value, exp.Schema):
- value = f"{value.this} <{self.expressions(value)}>"
+ this = expression.this
+ if isinstance(this, exp.Schema):
+ this = f"{this.this} <{self.expressions(this)}>"
else:
- value = self.sql(value)
- return f"RETURNS {value}"
+ this = self.sql(this)
+ return f"RETURNS {this}"
def _create_sql(self, expression):
@@ -142,6 +142,11 @@ class BigQuery(Dialect):
),
}
+ FUNCTION_PARSERS = {
+ **parser.Parser.FUNCTION_PARSERS,
+ }
+ FUNCTION_PARSERS.pop("TRIM")
+
NO_PAREN_FUNCTIONS = {
**parser.Parser.NO_PAREN_FUNCTIONS,
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
@@ -174,6 +179,7 @@ class BigQuery(Dialect):
exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
+ exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
@@ -200,9 +206,7 @@ class BigQuery(Dialect):
exp.VolatilityProperty,
}
- WITH_PROPERTIES = {
- exp.AnonymousProperty,
- }
+ WITH_PROPERTIES = {exp.Property}
EXPLICIT_UNION = True
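
BigQuery spells TRIM with plain function arguments rather than the ANSI TRIM(<chars> FROM <value>) form, so the dialect drops the specialized TRIM parser and re-emits the two-argument call via the new exp.Trim transform. A minimal round-trip sketch using sqlglot's public API:

    import sqlglot

    # TRIM now parses as an ordinary function in BigQuery and renders
    # back as TRIM(value, characters).
    print(sqlglot.transpile("SELECT TRIM(x, 'ab')", read="bigquery", write="bigquery")[0])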
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 332b4c1..cbed72e 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -21,14 +21,15 @@ class ClickHouse(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
- "FINAL": TokenType.FINAL,
+ "ASOF": TokenType.ASOF,
"DATETIME64": TokenType.DATETIME,
- "INT8": TokenType.TINYINT,
+ "FINAL": TokenType.FINAL,
+ "FLOAT32": TokenType.FLOAT,
+ "FLOAT64": TokenType.DOUBLE,
"INT16": TokenType.SMALLINT,
"INT32": TokenType.INT,
"INT64": TokenType.BIGINT,
- "FLOAT32": TokenType.FLOAT,
- "FLOAT64": TokenType.DOUBLE,
+ "INT8": TokenType.TINYINT,
"TUPLE": TokenType.STRUCT,
}
@@ -38,6 +39,10 @@ class ClickHouse(Dialect):
"MAP": parse_var_map,
}
+ JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF}
+
+ TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY}
+
def _parse_table(self, schema=False):
this = super()._parse_table(schema)
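
With ASOF registered as a join kind (and ANY removed from the table-alias tokens so ANY joins parse cleanly), inequality-matching joins should now round-trip. A minimal sketch:

    import sqlglot

    sql = "SELECT * FROM t1 ASOF JOIN t2 ON t1.key = t2.key AND t1.ts >= t2.ts"
    print(sqlglot.parse_one(sql, read="clickhouse").sql(dialect="clickhouse"))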
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 8c497ab..c87f8d8 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -289,19 +289,19 @@ def struct_extract_sql(self, expression):
return f"{this}.{struct_key}"
-def var_map_sql(self, expression):
+def var_map_sql(self, expression, map_func_name="MAP"):
keys = expression.args["keys"]
values = expression.args["values"]
if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
self.unsupported("Cannot convert array columns into map.")
- return f"MAP({self.format_args(keys, values)})"
+ return f"{map_func_name}({self.format_args(keys, values)})"
args = []
for key, value in zip(keys.expressions, values.expressions):
args.append(self.sql(key))
args.append(self.sql(value))
- return f"MAP({self.format_args(*args)})"
+ return f"{map_func_name}({self.format_args(*args)})"
def format_time_lambda(exp_class, dialect, default=None):
@@ -336,18 +336,13 @@ def create_with_partitions_sql(self, expression):
if has_schema and is_partitionable:
expression = expression.copy()
prop = expression.find(exp.PartitionedByProperty)
- value = prop and prop.args.get("value")
- if prop and not isinstance(value, exp.Schema):
+ this = prop and prop.this
+ if prop and not isinstance(this, exp.Schema):
schema = expression.this
- columns = {v.name.upper() for v in value.expressions}
+ columns = {v.name.upper() for v in this.expressions}
partitions = [col for col in schema.expressions if col.name.upper() in columns]
- schema.set(
- "expressions",
- [e for e in schema.expressions if e not in partitions],
- )
- prop.replace(
- exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
- )
+ schema.set("expressions", [e for e in schema.expressions if e not in partitions])
+ prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
expression.set("this", schema)
return self.create_sql(expression)
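
The new map_func_name parameter lets a dialect keep var_map_sql's key/value handling while renaming the constructor, exactly as the Snowflake hunk below does for OBJECT_CONSTRUCT. A sketch of how a dialect generator wires it up:

    from sqlglot import exp
    from sqlglot.dialects.dialect import var_map_sql

    TRANSFORMS = {
        # Emit OBJECT_CONSTRUCT(k1, v1, ...) instead of MAP(k1, v1, ...)
        exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
        exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
    }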
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index eb420aa..358eced 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -153,7 +153,7 @@ class Drill(Dialect):
exp.If: if_sql,
exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
- exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
+ exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Pivot: no_pivot_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.StrPosition: str_position_sql,
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index cff7139..cbb39c2 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -61,9 +61,7 @@ def _array_sort(self, expression):
def _property_sql(self, expression):
- key = expression.name
- value = self.sql(expression, "value")
- return f"'{key}'={value}"
+ return f"'{expression.name}'={self.sql(expression, 'value')}"
def _str_to_unix(self, expression):
@@ -250,7 +248,7 @@ class Hive(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP, # type: ignore
- exp.AnonymousProperty: _property_sql,
+ exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.ArrayConcat: rename_func("CONCAT"),
@@ -262,7 +260,7 @@ class Hive(Dialect):
exp.DateStrToDate: rename_func("TO_DATE"),
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
- exp.FileFormatProperty: lambda self, e: f"STORED AS {e.text('value').upper()}",
+ exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}",
exp.If: if_sql,
exp.Index: _index_sql,
exp.ILike: no_ilike_sql,
@@ -285,7 +283,7 @@ class Hive(Dialect):
exp.StrToTime: _str_to_time,
exp.StrToUnix: _str_to_unix,
exp.StructExtract: struct_extract_sql,
- exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'value')}",
+ exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@@ -298,11 +296,11 @@ class Hive(Dialect):
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
- exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}",
+ exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
}
- WITH_PROPERTIES = {exp.AnonymousProperty}
+ WITH_PROPERTIES = {exp.Property}
ROOT_PROPERTIES = {
exp.PartitionedByProperty,
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 93a60f4..7627b6e 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -453,6 +453,7 @@ class MySQL(Dialect):
exp.CharacterSetProperty,
exp.CollateProperty,
exp.SchemaCommentProperty,
+ exp.LikeProperty,
}
WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
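
Adding exp.LikeProperty to ROOT_PROPERTIES keeps CREATE TABLE ... LIKE at the top level instead of pushing it into a WITH (...) clause. A hedged round-trip sketch:

    import sqlglot

    print(sqlglot.transpile("CREATE TABLE t2 LIKE t1", read="mysql", write="mysql")[0])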
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 870d2b9..ceaf9ba 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -1,7 +1,7 @@
from __future__ import annotations
-from sqlglot import exp, generator, tokens, transforms
-from sqlglot.dialects.dialect import Dialect, no_ilike_sql
+from sqlglot import exp, generator, parser, tokens, transforms
+from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func
from sqlglot.helper import csv
from sqlglot.tokens import TokenType
@@ -37,6 +37,12 @@ class Oracle(Dialect):
"YYYY": "%Y", # 2015
}
+ class Parser(parser.Parser):
+ FUNCTIONS = {
+ **parser.Parser.FUNCTIONS,
+ "DECODE": exp.Matches.from_arg_list,
+ }
+
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -58,6 +64,7 @@ class Oracle(Dialect):
**transforms.UNALIAS_GROUP, # type: ignore
exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql,
+ exp.Matches: rename_func("DECODE"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
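
DECODE is now parsed into the new exp.Matches node (defined in the expressions.py hunk below) and rendered back as DECODE. A minimal sketch:

    import sqlglot

    sql = "SELECT DECODE(x, 1, 'one', 2, 'two', 'other') FROM t"
    print(sqlglot.parse_one(sql, read="oracle").sql(dialect="oracle"))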
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 4353164..1cb5025 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -74,6 +74,27 @@ def _trim_sql(self, expression):
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
+def _string_agg_sql(self, expression):
+ expression = expression.copy()
+ separator = expression.args.get("separator") or exp.Literal.string(",")
+
+ order = ""
+ this = expression.this
+ if isinstance(this, exp.Order):
+ if this.this:
+ this = this.this
+ this.pop()
+ order = self.sql(expression.this) # Order has a leading space
+
+ return f"STRING_AGG({self.format_args(this, separator)}{order})"
+
+
+def _datatype_sql(self, expression):
+ if expression.this == exp.DataType.Type.ARRAY:
+ return f"{self.expressions(expression, flat=True)}[]"
+ return self.datatype_sql(expression)
+
+
def _auto_increment_to_serial(expression):
auto = expression.find(exp.AutoIncrementColumnConstraint)
@@ -191,25 +212,27 @@ class Postgres(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ALWAYS": TokenType.ALWAYS,
- "BY DEFAULT": TokenType.BY_DEFAULT,
- "IDENTITY": TokenType.IDENTITY,
- "GENERATED": TokenType.GENERATED,
- "DOUBLE PRECISION": TokenType.DOUBLE,
- "BIGSERIAL": TokenType.BIGSERIAL,
- "SERIAL": TokenType.SERIAL,
- "SMALLSERIAL": TokenType.SMALLSERIAL,
- "UUID": TokenType.UUID,
- "TEMP": TokenType.TEMPORARY,
- "BEGIN TRANSACTION": TokenType.BEGIN,
"BEGIN": TokenType.COMMAND,
+ "BEGIN TRANSACTION": TokenType.BEGIN,
+ "BIGSERIAL": TokenType.BIGSERIAL,
+ "BY DEFAULT": TokenType.BY_DEFAULT,
"COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
+ "DOUBLE PRECISION": TokenType.DOUBLE,
+ "GENERATED": TokenType.GENERATED,
+ "GRANT": TokenType.COMMAND,
+ "HSTORE": TokenType.HSTORE,
+ "IDENTITY": TokenType.IDENTITY,
+ "JSONB": TokenType.JSONB,
"REFRESH": TokenType.COMMAND,
"REINDEX": TokenType.COMMAND,
"RESET": TokenType.COMMAND,
"REVOKE": TokenType.COMMAND,
- "GRANT": TokenType.COMMAND,
+ "SERIAL": TokenType.SERIAL,
+ "SMALLSERIAL": TokenType.SMALLSERIAL,
+ "TEMP": TokenType.TEMPORARY,
+ "UUID": TokenType.UUID,
**{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
**{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
}
@@ -265,4 +288,7 @@ class Postgres(Dialect):
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
+ exp.DataType: _datatype_sql,
+ exp.GroupConcat: _string_agg_sql,
+ exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
}
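
Taken together, these transforms give Postgres native array constructors and STRING_AGG output for exp.GroupConcat, with any ORDER BY folded inside the call. A sketch, assuming the new global STRING_AGG parser (see the parser.py hunk below) handles the Postgres form:

    import sqlglot

    print(sqlglot.transpile("SELECT ARRAY[1, 2, 3]", read="postgres", write="postgres")[0])
    print(sqlglot.transpile("SELECT STRING_AGG(x, ',' ORDER BY y)", read="postgres", write="postgres")[0])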
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 9d5cc11..1a09037 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -171,16 +171,7 @@ class Presto(Dialect):
STRUCT_DELIMITER = ("(", ")")
- ROOT_PROPERTIES = {
- exp.SchemaCommentProperty,
- }
-
- WITH_PROPERTIES = {
- exp.PartitionedByProperty,
- exp.FileFormatProperty,
- exp.AnonymousProperty,
- exp.TableFormatProperty,
- }
+ ROOT_PROPERTIES = {exp.SchemaCommentProperty}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -231,7 +222,8 @@ class Presto(Dialect):
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.StructExtract: struct_extract_sql,
- exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT = '{e.text('value').upper()}'",
+ exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
+ exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.TimeStrToDate: _date_parse_sql,
exp.TimeStrToTime: _date_parse_sql,
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index a9b12fb..cd50979 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from sqlglot import exp
+from sqlglot import exp, transforms
from sqlglot.dialects.postgres import Postgres
from sqlglot.tokens import TokenType
@@ -18,12 +18,14 @@ class Redshift(Postgres):
KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS, # type: ignore
+ "COPY": TokenType.COMMAND,
"GEOMETRY": TokenType.GEOMETRY,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER,
"TIME": TokenType.TIMESTAMP,
"TIMETZ": TokenType.TIMESTAMPTZ,
+ "UNLOAD": TokenType.COMMAND,
"VARBYTE": TokenType.VARBINARY,
"SIMILAR TO": TokenType.SIMILAR_TO,
}
@@ -35,3 +37,17 @@ class Redshift(Postgres):
exp.DataType.Type.VARBINARY: "VARBYTE",
exp.DataType.Type.INT: "INTEGER",
}
+
+ ROOT_PROPERTIES = {
+ exp.DistKeyProperty,
+ exp.SortKeyProperty,
+ exp.DistStyleProperty,
+ }
+
+ TRANSFORMS = {
+ **Postgres.Generator.TRANSFORMS, # type: ignore
+ **transforms.ELIMINATE_DISTINCT_ON, # type: ignore
+ exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
+ exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
+ exp.DistStyleProperty: lambda self, e: self.naked_property(e),
+ }
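
With DISTKEY, SORTKEY and DISTSTYLE registered as root properties and given dedicated renderers, Redshift table attributes should survive a round trip. A hedged sketch:

    import sqlglot

    sql = "CREATE TABLE t (a INT, b INT) DISTKEY(a) SORTKEY(b)"
    print(sqlglot.transpile(sql, read="redshift", write="redshift")[0])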
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index a96bd80..46155ff 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
inline_array_sql,
rename_func,
+ var_map_sql,
)
from sqlglot.expressions import Literal
from sqlglot.helper import seq_get
@@ -100,6 +101,14 @@ def _parse_date_part(self):
return self.expression(exp.Extract, this=this, expression=expression)
+def _datatype_sql(self, expression):
+ if expression.this == exp.DataType.Type.ARRAY:
+ return "ARRAY"
+ elif expression.this == exp.DataType.Type.MAP:
+ return "OBJECT"
+ return self.datatype_sql(expression)
+
+
class Snowflake(Dialect):
null_ordering = "nulls_are_large"
time_format = "'yyyy-mm-dd hh24:mi:ss'"
@@ -142,6 +151,8 @@ class Snowflake(Dialect):
"TO_TIMESTAMP": _snowflake_to_timestamp,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
+ "DECODE": exp.Matches.from_arg_list,
+ "OBJECT_CONSTRUCT": parser.parse_var_map,
}
FUNCTION_PARSERS = {
@@ -195,16 +206,20 @@ class Snowflake(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
+ exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
+ exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
+ exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
+ exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
+ exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
+ exp.Matches: rename_func("DECODE"),
+ exp.StrPosition: rename_func("POSITION"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
- exp.UnixToTime: _unix_to_time_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
- exp.Array: inline_array_sql,
- exp.StrPosition: rename_func("POSITION"),
- exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
- exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
+ exp.UnixToTime: _unix_to_time_sql,
}
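
OBJECT_CONSTRUCT now parses as a variadic map (parse_var_map) and is rendered back through var_map_sql, so key/value pairs survive the round trip; DECODE gets the same exp.Matches treatment as Oracle. A minimal sketch:

    import sqlglot

    sql = "SELECT OBJECT_CONSTRUCT('a', 1, 'b', 2)"
    print(sqlglot.parse_one(sql, read="snowflake").sql(dialect="snowflake"))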
TYPE_MAPPING = {
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 4e404b8..16083d1 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -98,7 +98,7 @@ class Spark(Hive):
TRANSFORMS = {
**Hive.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
- exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
+ exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 87b98a5..bbb752b 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -13,6 +13,23 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType
+# https://www.sqlite.org/lang_aggfunc.html#group_concat
+def _group_concat_sql(self, expression):
+ this = expression.this
+ distinct = expression.find(exp.Distinct)
+ if distinct:
+ this = distinct.expressions[0]
+ distinct = "DISTINCT "
+
+ if isinstance(expression.this, exp.Order):
+ self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
+ if expression.this.this and not distinct:
+ this = expression.this.this
+
+ separator = expression.args.get("separator")
+ return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
+
+
class SQLite(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
@@ -62,6 +79,7 @@ class SQLite(Dialect):
exp.Levenshtein: rename_func("EDITDIST3"),
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
+ exp.GroupConcat: _group_concat_sql,
}
def transaction_sql(self, expression):
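
The new _group_concat_sql maps exp.GroupConcat onto SQLite's two-argument GROUP_CONCAT and warns on ORDER BY, which SQLite cannot express. A sketch transpiling from Postgres:

    import sqlglot

    # Expected shape: SELECT GROUP_CONCAT(x, ';')
    print(sqlglot.transpile("SELECT STRING_AGG(x, ';')", read="postgres", write="sqlite")[0])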
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index d3b83de..07ce38b 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -17,6 +17,7 @@ FULL_FORMAT_TIME_MAPPING = {
"mm": "%B",
"m": "%B",
}
+
DATE_DELTA_INTERVAL = {
"year": "year",
"yyyy": "year",
@@ -37,11 +38,12 @@ DATE_DELTA_INTERVAL = {
DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})")
+
# N = Numeric, C=Currency
TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
-def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
+def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
def _format_time(args):
return exp_class(
this=seq_get(args, 1),
@@ -58,7 +60,7 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
return _format_time
-def parse_format(args):
+def _parse_format(args):
fmt = seq_get(args, 1)
number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
if number_fmt:
@@ -78,7 +80,7 @@ def generate_date_delta_with_unit_sql(self, e):
return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
-def generate_format_sql(self, e):
+def _format_sql(self, e):
fmt = (
e.args["format"]
if isinstance(e, exp.NumberToStr)
@@ -87,6 +89,28 @@ def generate_format_sql(self, e):
return f"FORMAT({self.format_args(e.this, fmt)})"
+def _string_agg_sql(self, e):
+ e = e.copy()
+
+ this = e.this
+ distinct = e.find(exp.Distinct)
+ if distinct:
+ # exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
+ self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.")
+ this = distinct.expressions[0]
+ distinct.pop()
+
+ order = ""
+ if isinstance(e.this, exp.Order):
+ if e.this.this:
+ this = e.this.this
+ e.this.this.pop()
+ order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space
+
+ separator = e.args.get("separator") or exp.Literal.string(",")
+ return f"STRING_AGG({self.format_args(this, separator)}){order}"
+
+
class TSQL(Dialect):
null_ordering = "nulls_are_small"
time_format = "'yyyy-mm-dd hh:mm:ss'"
@@ -228,14 +252,14 @@ class TSQL(Dialect):
"ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
- "DATENAME": tsql_format_time_lambda(exp.TimeToStr, full_format_mapping=True),
- "DATEPART": tsql_format_time_lambda(exp.TimeToStr),
+ "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
+ "DATEPART": _format_time_lambda(exp.TimeToStr),
"GETDATE": exp.CurrentDate.from_arg_list,
"IIF": exp.If.from_arg_list,
"LEN": exp.Length.from_arg_list,
"REPLICATE": exp.Repeat.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
- "FORMAT": parse_format,
+ "FORMAT": _parse_format,
}
VAR_LENGTH_DATATYPES = {
@@ -298,6 +322,7 @@ class TSQL(Dialect):
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
exp.If: rename_func("IIF"),
- exp.NumberToStr: generate_format_sql,
- exp.TimeToStr: generate_format_sql,
+ exp.NumberToStr: _format_sql,
+ exp.TimeToStr: _format_sql,
+ exp.GroupConcat: _string_agg_sql,
}
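
_string_agg_sql reattaches the ORDER BY as a WITHIN GROUP suffix, where T-SQL requires it, and flags DISTINCT as unsupported. A hedged round-trip sketch:

    import sqlglot

    sql = "SELECT STRING_AGG(x, ',') WITHIN GROUP (ORDER BY y) FROM t"
    print(sqlglot.parse_one(sql, read="tsql").sql(dialect="tsql"))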
diff --git a/sqlglot/errors.py b/sqlglot/errors.py
index 23a08bd..b5ef5ad 100644
--- a/sqlglot/errors.py
+++ b/sqlglot/errors.py
@@ -22,7 +22,40 @@ class UnsupportedError(SqlglotError):
class ParseError(SqlglotError):
- pass
+ def __init__(
+ self,
+ message: str,
+ errors: t.Optional[t.List[t.Dict[str, t.Any]]] = None,
+ ):
+ super().__init__(message)
+ self.errors = errors or []
+
+ @classmethod
+ def new(
+ cls,
+ message: str,
+ description: t.Optional[str] = None,
+ line: t.Optional[int] = None,
+ col: t.Optional[int] = None,
+ start_context: t.Optional[str] = None,
+ highlight: t.Optional[str] = None,
+ end_context: t.Optional[str] = None,
+ into_expression: t.Optional[str] = None,
+ ) -> ParseError:
+ return cls(
+ message,
+ [
+ {
+ "description": description,
+ "line": line,
+ "col": col,
+ "start_context": start_context,
+ "highlight": highlight,
+ "end_context": end_context,
+ "into_expression": into_expression,
+ }
+ ],
+ )
class TokenError(SqlglotError):
@@ -41,9 +74,13 @@ class ExecuteError(SqlglotError):
pass
-def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str:
+def concat_messages(errors: t.Sequence[t.Any], maximum: int) -> str:
msg = [str(e) for e in errors[:maximum]]
remaining = len(errors) - maximum
if remaining > 0:
msg.append(f"... and {remaining} more")
return "\n\n".join(msg)
+
+
+def merge_errors(errors: t.Sequence[ParseError]) -> t.List[t.Dict[str, t.Any]]:
+ return [e_dict for error in errors for e_dict in error.errors]
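
ParseError now carries a list of structured error dicts alongside the rendered message, so callers can inspect line and column information programmatically. A minimal sketch:

    from sqlglot import parse_one
    from sqlglot.errors import ParseError

    try:
        parse_one("SELECT 1 +")
    except ParseError as e:
        for error in e.errors:
            print(error["line"], error["col"], error["description"])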
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index ed80cc9..e6cfcdd 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -122,7 +122,6 @@ def interval(this, unit):
ENV = {
- "__builtins__": {},
"exp": exp,
# aggs
"SUM": filter_nulls(sum),
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index cb2543c..908b80a 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -115,6 +115,9 @@ class PythonExecutor:
sink = self.table(context.columns)
for reader in table_iter:
+ if len(sink) >= step.limit:
+ break
+
if condition and not context.eval(condition):
continue
@@ -123,9 +126,6 @@ class PythonExecutor:
else:
sink.append(reader.row)
- if len(sink) >= step.limit:
- break
-
return self.context({step.name: sink})
def static(self):
@@ -288,21 +288,32 @@ class PythonExecutor:
end = 1
length = len(context.table)
table = self.table(list(step.group) + step.aggregations)
+ condition = self.generate(step.condition)
- for i in range(length):
- context.set_index(i)
- key = context.eval_tuple(group_by)
- group = key if group is None else group
- end += 1
- if key != group:
- context.set_range(start, end - 2)
- table.append(group + context.eval_tuple(aggregations))
- group = key
- start = end - 2
- if i == length - 1:
- context.set_range(start, end - 1)
+ def add_row():
+ if not condition or context.eval(condition):
table.append(group + context.eval_tuple(aggregations))
+ if length:
+ for i in range(length):
+ context.set_index(i)
+ key = context.eval_tuple(group_by)
+ group = key if group is None else group
+ end += 1
+ if key != group:
+ context.set_range(start, end - 2)
+ add_row()
+ group = key
+ start = end - 2
+ if len(table.rows) >= step.limit:
+ break
+ if i == length - 1:
+ context.set_range(start, end - 1)
+ add_row()
+ elif step.limit > 0:
+ context.set_range(0, 0)
+ table.append(context.eval_tuple(group_by) + context.eval_tuple(aggregations))
+
context = self.context({step.name: table, **{name: table for name in context.tables}})
if step.projections:
@@ -311,11 +322,9 @@ class PythonExecutor:
def sort(self, step, context):
projections = self.generate_tuple(step.projections)
-
projection_columns = [p.alias_or_name for p in step.projections]
all_columns = list(context.columns) + projection_columns
sink = self.table(all_columns)
-
for reader, ctx in context:
sink.append(reader.row + ctx.eval_tuple(projections))
@@ -401,8 +410,9 @@ class Python(Dialect):
exp.Boolean: lambda self, e: "True" if e.this else "False",
exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
+ exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
- exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}",
+ exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
exp.Is: lambda self, e: self.binary(e, "is"),
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",
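
The aggregation loop now evaluates step.condition per group (backing HAVING) and checks step.limit while scanning and aggregating, so limited queries stop early. A hypothetical sketch, assuming the executor's execute() accepts an in-memory tables mapping:

    from sqlglot.executor import execute

    tables = {"t": [{"a": 1}, {"a": 1}, {"a": 2}]}
    result = execute("SELECT a, COUNT(*) AS c FROM t GROUP BY a HAVING COUNT(*) > 1", tables=tables)
    print(result.rows)  # expected: [(1, 2)]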
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index beafca8..96b32f1 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -43,14 +43,14 @@ class Expression(metaclass=_Expression):
key = "Expression"
arg_types = {"this": True}
- __slots__ = ("args", "parent", "arg_key", "type", "comment")
+ __slots__ = ("args", "parent", "arg_key", "type", "comments")
def __init__(self, **args):
self.args = args
self.parent = None
self.arg_key = None
self.type = None
- self.comment = None
+ self.comments = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
@@ -88,19 +88,6 @@ class Expression(metaclass=_Expression):
return field.this
return ""
- def find_comment(self, key: str) -> str:
- """
- Finds the comment that is attached to a specified child node.
-
- Args:
- key: the key of the target child node (e.g. "this", "expression", etc).
-
- Returns:
- The comment attached to the child node, or the empty string, if it doesn't exist.
- """
- field = self.args.get(key)
- return field.comment if isinstance(field, Expression) else ""
-
@property
def is_string(self):
return isinstance(self, Literal) and self.args["is_string"]
@@ -137,7 +124,7 @@ class Expression(metaclass=_Expression):
def __deepcopy__(self, memo):
copy = self.__class__(**deepcopy(self.args))
- copy.comment = self.comment
+ copy.comments = self.comments
copy.type = self.type
return copy
@@ -369,7 +356,7 @@ class Expression(metaclass=_Expression):
)
for k, vs in self.args.items()
}
- args["comment"] = self.comment
+ args["comments"] = self.comments
args["type"] = self.type
args = {k: v for k, v in args.items() if v or not hide_missing}
@@ -767,7 +754,7 @@ class NotNullColumnConstraint(ColumnConstraintKind):
class PrimaryKeyColumnConstraint(ColumnConstraintKind):
- pass
+ arg_types = {"desc": False}
class UniqueColumnConstraint(ColumnConstraintKind):
@@ -819,6 +806,12 @@ class Unique(Expression):
arg_types = {"expressions": True}
+# https://www.postgresql.org/docs/9.1/sql-selectinto.html
+# https://docs.aws.amazon.com/redshift/latest/dg/r_SELECT_INTO.html#r_SELECT_INTO-examples
+class Into(Expression):
+ arg_types = {"this": True, "temporary": False, "unlogged": False}
+
+
class From(Expression):
arg_types = {"expressions": True}
@@ -1065,67 +1058,67 @@ class Property(Expression):
class TableFormatProperty(Property):
- pass
+ arg_types = {"this": True}
class PartitionedByProperty(Property):
- pass
+ arg_types = {"this": True}
class FileFormatProperty(Property):
- pass
+ arg_types = {"this": True}
class DistKeyProperty(Property):
- pass
+ arg_types = {"this": True}
class SortKeyProperty(Property):
- pass
+ arg_types = {"this": True, "compound": False}
class DistStyleProperty(Property):
- pass
+ arg_types = {"this": True}
+
+
+class LikeProperty(Property):
+ arg_types = {"this": True, "expressions": False}
class LocationProperty(Property):
- pass
+ arg_types = {"this": True}
class EngineProperty(Property):
- pass
+ arg_types = {"this": True}
class AutoIncrementProperty(Property):
- pass
+ arg_types = {"this": True}
class CharacterSetProperty(Property):
- arg_types = {"this": True, "value": True, "default": True}
+ arg_types = {"this": True, "default": True}
class CollateProperty(Property):
- pass
+ arg_types = {"this": True}
class SchemaCommentProperty(Property):
- pass
-
-
-class AnonymousProperty(Property):
- pass
+ arg_types = {"this": True}
class ReturnsProperty(Property):
- arg_types = {"this": True, "value": True, "is_table": False}
+ arg_types = {"this": True, "is_table": False}
class LanguageProperty(Property):
- pass
+ arg_types = {"this": True}
class ExecuteAsProperty(Property):
- pass
+ arg_types = {"this": True}
class VolatilityProperty(Property):
@@ -1135,27 +1128,36 @@ class VolatilityProperty(Property):
class Properties(Expression):
arg_types = {"expressions": True}
- PROPERTY_KEY_MAPPING = {
+ NAME_TO_PROPERTY = {
"AUTO_INCREMENT": AutoIncrementProperty,
- "CHARACTER_SET": CharacterSetProperty,
+ "CHARACTER SET": CharacterSetProperty,
"COLLATE": CollateProperty,
"COMMENT": SchemaCommentProperty,
+ "DISTKEY": DistKeyProperty,
+ "DISTSTYLE": DistStyleProperty,
"ENGINE": EngineProperty,
+ "EXECUTE AS": ExecuteAsProperty,
"FORMAT": FileFormatProperty,
+ "LANGUAGE": LanguageProperty,
"LOCATION": LocationProperty,
"PARTITIONED_BY": PartitionedByProperty,
- "TABLE_FORMAT": TableFormatProperty,
- "DISTKEY": DistKeyProperty,
- "DISTSTYLE": DistStyleProperty,
+ "RETURNS": ReturnsProperty,
"SORTKEY": SortKeyProperty,
+ "TABLE_FORMAT": TableFormatProperty,
}
+ PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}
+
@classmethod
def from_dict(cls, properties_dict) -> Properties:
expressions = []
for key, value in properties_dict.items():
- property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
- expressions.append(property_cls(this=Literal.string(key), value=convert(value)))
+ property_cls = cls.NAME_TO_PROPERTY.get(key.upper())
+ if property_cls:
+ expressions.append(property_cls(this=convert(value)))
+ else:
+ expressions.append(Property(this=Literal.string(key), value=convert(value)))
+
return cls(expressions=expressions)
@@ -1383,6 +1385,7 @@ class Select(Subqueryable):
"expressions": False,
"hint": False,
"distinct": False,
+ "into": False,
"from": False,
**QUERY_MODIFIERS,
}
@@ -2015,6 +2018,7 @@ class DataType(Expression):
DECIMAL = auto()
BOOLEAN = auto()
JSON = auto()
+ JSONB = auto()
INTERVAL = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
@@ -2029,6 +2033,7 @@ class DataType(Expression):
STRUCT = auto()
NULLABLE = auto()
HLLSKETCH = auto()
+ HSTORE = auto()
SUPER = auto()
SERIAL = auto()
SMALLSERIAL = auto()
@@ -2109,7 +2114,7 @@ class Transaction(Command):
class Commit(Command):
- arg_types = {} # type: ignore
+ arg_types = {"chain": False}
class Rollback(Command):
@@ -2442,7 +2447,7 @@ class ArrayFilter(Func):
class ArraySize(Func):
- pass
+ arg_types = {"this": True, "expression": False}
class ArraySort(Func):
@@ -2726,6 +2731,16 @@ class VarMap(Func):
is_var_len_args = True
+class Matches(Func):
+ """Oracle/Snowflake decode.
+ https://docs.oracle.com/cd/B19306_01/server.102/b14200/functions040.htm
+ Pattern matching MATCHES(value, search1, result1, ...searchN, resultN, else)
+ """
+
+ arg_types = {"this": True, "expressions": True}
+ is_var_len_args = True
+
+
class Max(AggFunc):
pass
@@ -2785,6 +2800,10 @@ class Round(Func):
arg_types = {"this": True, "decimals": False}
+class RowNumber(Func):
+ arg_types: t.Dict[str, t.Any] = {}
+
+
class SafeDivide(Func):
arg_types = {"this": True, "expression": True}
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index ffb34eb..47774fc 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -1,19 +1,16 @@
from __future__ import annotations
import logging
-import re
import typing as t
from sqlglot import exp
-from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors
+from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
logger = logging.getLogger("sqlglot")
-NEWLINE_RE = re.compile("\r\n?|\n")
-
class Generator:
"""
@@ -58,11 +55,11 @@ class Generator:
"""
TRANSFORMS = {
- exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})",
+ exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
@@ -97,16 +94,17 @@ class Generator:
exp.DistStyleProperty,
exp.DistKeyProperty,
exp.SortKeyProperty,
+ exp.LikeProperty,
}
WITH_PROPERTIES = {
- exp.AnonymousProperty,
+ exp.Property,
exp.FileFormatProperty,
exp.PartitionedByProperty,
exp.TableFormatProperty,
}
- WITH_SEPARATED_COMMENTS = (exp.Select,)
+ WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
__slots__ = (
"time_mapping",
@@ -211,7 +209,7 @@ class Generator:
for msg in self.unsupported_messages:
logger.warning(msg)
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
- raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported))
+ raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported))
return sql
@@ -226,25 +224,24 @@ class Generator:
def seg(self, sql, sep=" "):
return f"{self.sep(sep)}{sql}"
- def maybe_comment(self, sql, expression, single_line=False):
- comment = expression.comment if self._comments else None
-
- if not comment:
- return sql
-
+ def pad_comment(self, comment):
comment = " " + comment if comment[0].strip() else comment
comment = comment + " " if comment[-1].strip() else comment
+ return comment
- if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
- return f"/*{comment}*/{self.sep()}{sql}"
+ def maybe_comment(self, sql, expression):
+ comments = expression.comments if self._comments else None
- if not self.pretty:
- return f"{sql} /*{comment}*/"
+ if not comments:
+ return sql
+
+ sep = "\n" if self.pretty else " "
+ comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments)
- if not NEWLINE_RE.search(comment):
- return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
+ if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
+ return f"{comments}{self.sep()}{sql}"
- return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/"
+ return f"{sql} {comments}"
def wrap(self, expression):
this_sql = self.indent(
@@ -387,8 +384,11 @@ class Generator:
def notnullcolumnconstraint_sql(self, _):
return "NOT NULL"
- def primarykeycolumnconstraint_sql(self, _):
- return "PRIMARY KEY"
+ def primarykeycolumnconstraint_sql(self, expression):
+ desc = expression.args.get("desc")
+ if desc is not None:
+ return f"PRIMARY KEY{' DESC' if desc else ' ASC'}"
+ return f"PRIMARY KEY"
def uniquecolumnconstraint_sql(self, _):
return "UNIQUE"
@@ -546,36 +546,33 @@ class Generator:
def root_properties(self, properties):
if properties.expressions:
- return self.sep() + self.expressions(
- properties,
- indent=False,
- sep=" ",
- )
+ return self.sep() + self.expressions(properties, indent=False, sep=" ")
return ""
def properties(self, properties, prefix="", sep=", "):
if properties.expressions:
- expressions = self.expressions(
- properties,
- sep=sep,
- indent=False,
- )
+ expressions = self.expressions(properties, sep=sep, indent=False)
return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}"
return ""
def with_properties(self, properties):
- return self.properties(
- properties,
- prefix="WITH",
- )
+ return self.properties(properties, prefix="WITH")
def property_sql(self, expression):
- if isinstance(expression.this, exp.Literal):
- key = expression.this.this
- else:
- key = expression.name
- value = self.sql(expression, "value")
- return f"{key}={value}"
+ property_cls = expression.__class__
+ if property_cls == exp.Property:
+ return f"{expression.name}={self.sql(expression, 'value')}"
+
+ property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls)
+ if not property_name:
+ self.unsupported(f"Unsupported property {property_name}")
+
+ return f"{property_name}={self.sql(expression, 'this')}"
+
+ def likeproperty_sql(self, expression):
+ options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions)
+ options = f" {options}" if options else ""
+ return f"LIKE {self.sql(expression, 'this')}{options}"
def insert_sql(self, expression):
overwrite = expression.args.get("overwrite")
@@ -700,6 +697,11 @@ class Generator:
def var_sql(self, expression):
return self.sql(expression, "this")
+ def into_sql(self, expression):
+ temporary = " TEMPORARY" if expression.args.get("temporary") else ""
+ unlogged = " UNLOGGED" if expression.args.get("unlogged") else ""
+ return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}"
+
def from_sql(self, expression):
expressions = self.expressions(expression, flat=True)
return f"{self.seg('FROM')} {expressions}"
@@ -883,6 +885,7 @@ class Generator:
sql = self.query_modifiers(
expression,
f"SELECT{hint}{distinct}{expressions}",
+ self.sql(expression, "into", comment=False),
self.sql(expression, "from", comment=False),
)
return self.prepend_ctes(expression, sql)
@@ -1061,6 +1064,11 @@ class Generator:
else:
return f"TRIM({target})"
+ def concat_sql(self, expression):
+ if len(expression.expressions) == 1:
+ return self.sql(expression.expressions[0])
+ return self.function_fallback_sql(expression)
+
def check_sql(self, expression):
this = self.sql(expression, key="this")
return f"CHECK ({this})"
@@ -1125,7 +1133,10 @@ class Generator:
return self.prepend_ctes(expression, sql)
def neg_sql(self, expression):
- return f"-{self.sql(expression, 'this')}"
+ # This makes sure we don't convert "- - 5" to "--5", which is a comment
+ this_sql = self.sql(expression, "this")
+ sep = " " if this_sql[0] == "-" else ""
+ return f"-{sep}{this_sql}"
def not_sql(self, expression):
return f"NOT {self.sql(expression, 'this')}"
@@ -1191,8 +1202,12 @@ class Generator:
def transaction_sql(self, *_):
return "BEGIN"
- def commit_sql(self, *_):
- return "COMMIT"
+ def commit_sql(self, expression):
+ chain = expression.args.get("chain")
+ if chain is not None:
+ chain = " AND CHAIN" if chain else " AND NO CHAIN"
+
+ return f"COMMIT{chain or ''}"
def rollback_sql(self, expression):
savepoint = expression.args.get("savepoint")
@@ -1334,15 +1349,15 @@ class Generator:
result_sqls = []
for i, e in enumerate(expressions):
sql = self.sql(e, comment=False)
- comment = self.maybe_comment("", e, single_line=True)
+ comments = self.maybe_comment("", e)
if self.pretty:
if self._leading_comma:
- result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}")
+ result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}")
else:
- result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}")
+ result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}")
else:
- result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}")
+ result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}")
result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
return self.indent(result_sqls, skip_first=False) if indent else result_sqls
@@ -1354,7 +1369,10 @@ class Generator:
return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"
def naked_property(self, expression):
- return f"{expression.name} {self.sql(expression, 'value')}"
+ property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__)
+ if not property_name:
+ self.unsupported(f"Unsupported property {expression.__class__.__name__}")
+ return f"{property_name} {self.sql(expression, 'this')}"
def set_operation(self, expression, op):
this = self.sql(expression, "this")
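
Since expressions now hold a list of comments, several comments attached to one node should all be emitted, and the neg_sql fix keeps nested negation from collapsing into a "--" line comment. A minimal sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT /* one */ /* two */ 1")[0])
    print(sqlglot.transpile("SELECT - -5")[0])  # stays "- -5", not "--5"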
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 8704e90..39e252c 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -68,6 +68,9 @@ def eliminate_subqueries(expression):
for cte_scope in root.cte_scopes:
# Append all the new CTEs from this existing CTE
for scope in cte_scope.traverse():
+ if scope is cte_scope:
+ # Don't try to eliminate this CTE itself
+ continue
new_cte = _eliminate(scope, existing_ctes, taken)
if new_cte:
new_ctes.append(new_cte)
@@ -97,6 +100,9 @@ def _eliminate(scope, existing_ctes, taken):
if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
return _eliminate_derived_table(scope, existing_ctes, taken)
+ if scope.is_cte:
+ return _eliminate_cte(scope, existing_ctes, taken)
+
def _eliminate_union(scope, existing_ctes, taken):
duplicate_cte_alias = existing_ctes.get(scope.expression)
@@ -127,26 +133,61 @@ def _eliminate_union(scope, existing_ctes, taken):
def _eliminate_derived_table(scope, existing_ctes, taken):
+ parent = scope.expression.parent
+ name, cte = _new_cte(scope, existing_ctes, taken)
+
+ table = exp.alias_(exp.table_(name), alias=parent.alias or name)
+ parent.replace(table)
+
+ return cte
+
+
+def _eliminate_cte(scope, existing_ctes, taken):
+ parent = scope.expression.parent
+ name, cte = _new_cte(scope, existing_ctes, taken)
+
+ with_ = parent.parent
+ parent.pop()
+ if not with_.expressions:
+ with_.pop()
+
+ # Rename references to this CTE
+ for child_scope in scope.parent.traverse():
+ for table, source in child_scope.selected_sources.values():
+ if source is scope:
+ new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name)
+ table.replace(new_table)
+
+ return cte
+
+
+def _new_cte(scope, existing_ctes, taken):
+ """
+ Returns:
+ tuple of (name, cte)
+ where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
+ If this CTE duplicates an existing CTE, `cte` will be None.
+ """
duplicate_cte_alias = existing_ctes.get(scope.expression)
parent = scope.expression.parent
- name = alias = parent.alias
+ name = parent.alias
- if not alias:
- name = alias = find_new_name(taken=taken, base="cte")
+ if not name:
+ name = find_new_name(taken=taken, base="cte")
if duplicate_cte_alias:
name = duplicate_cte_alias
- elif taken.get(alias):
- name = find_new_name(taken=taken, base=alias)
+ elif taken.get(name):
+ name = find_new_name(taken=taken, base=name)
taken[name] = scope
- table = exp.alias_(exp.table_(name), alias=alias)
- parent.replace(table)
-
if not duplicate_cte_alias:
existing_ctes[scope.expression] = name
- return exp.CTE(
+ cte = exp.CTE(
this=scope.expression,
alias=exp.TableAlias(this=exp.to_identifier(name)),
)
+ else:
+ cte = None
+ return name, cte
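
The refactor splits CTE creation into _new_cte and adds _eliminate_cte, so duplicate CTEs collapse to a single definition and references to the removed one are renamed. A minimal sketch of the derived-table path:

    import sqlglot
    from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

    expression = sqlglot.parse_one("SELECT * FROM (SELECT a FROM x) AS y")
    # Expected shape: WITH y AS (SELECT a FROM x) SELECT * FROM y AS y
    print(eliminate_subqueries(expression).sql())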
diff --git a/sqlglot/optimizer/lower_identities.py b/sqlglot/optimizer/lower_identities.py
new file mode 100644
index 0000000..1cc76cf
--- /dev/null
+++ b/sqlglot/optimizer/lower_identities.py
@@ -0,0 +1,92 @@
+from sqlglot import exp
+from sqlglot.helper import ensure_collection
+
+
+def lower_identities(expression):
+ """
+ Convert all unquoted identifiers to lower case.
+
+ Assuming the schema is all lower case, this essentially makes identifiers case-insensitive.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
+ >>> lower_identities(expression).sql()
+ 'SELECT bar.a AS A FROM "Foo".bar'
+
+    Args:
+        expression (sqlglot.Expression): expression to transform
+    Returns:
+        sqlglot.Expression: expression with unquoted identifiers lowercased
+ """
+ # We need to leave the output aliases unchanged, so the selects need special handling
+ _lower_selects(expression)
+
+ # These clauses can reference output aliases and also need special handling
+ _lower_order(expression)
+ _lower_having(expression)
+
+ # We've already handled these args, so don't traverse into them
+ traversed = {"expressions", "order", "having"}
+
+ if isinstance(expression, exp.Subquery):
+ # Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1
+ lower_identities(expression.this)
+ traversed |= {"this"}
+
+ if isinstance(expression, exp.Union):
+ # Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X
+ lower_identities(expression.left)
+ lower_identities(expression.right)
+ traversed |= {"this", "expression"}
+
+ for k, v in expression.args.items():
+ if k in traversed:
+ continue
+
+ for child in ensure_collection(v):
+ if isinstance(child, exp.Expression):
+ child.transform(_lower, copy=False)
+
+ return expression
+
+
+def _lower_selects(expression):
+ for e in expression.expressions:
+ # Leave output aliases as-is
+ e.unalias().transform(_lower, copy=False)
+
+
+def _lower_order(expression):
+ order = expression.args.get("order")
+
+ if not order:
+ return
+
+ output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)}
+
+ for ordered in order.expressions:
+ # Don't lower references to output aliases
+ if not (
+ isinstance(ordered.this, exp.Column)
+ and not ordered.this.table
+ and ordered.this.name in output_aliases
+ ):
+ ordered.transform(_lower, copy=False)
+
+
+def _lower_having(expression):
+ having = expression.args.get("having")
+
+ if not having:
+ return
+
+ # Don't lower references to output aliases
+ for agg in having.find_all(exp.AggFunc):
+ agg.transform(_lower, copy=False)
+
+
+def _lower(node):
+ if isinstance(node, exp.Identifier) and not node.quoted:
+ node.set("this", node.this.lower())
+ return node
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index d0e38cd..6819717 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -6,6 +6,7 @@ from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
+from sqlglot.optimizer.lower_identities import lower_identities
from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
@@ -17,6 +18,7 @@ from sqlglot.optimizer.quote_identities import quote_identities
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
RULES = (
+ lower_identities,
qualify_tables,
isolate_table_selects,
qualify_columns,
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index dbd680b..2046917 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -1,16 +1,15 @@
import itertools
from sqlglot import exp
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.scope import ScopeType, traverse_scope
def unnest_subqueries(expression):
"""
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
- Convert the subquery into a group by so it is not a many to many left join.
- Unnesting can only occur if the subquery does not have LIMIT or OFFSET.
- Unnesting non correlated subqueries only happens on IN statements or = ANY statements.
+ Convert scalar subqueries into cross joins.
+ Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
Example:
>>> import sqlglot
@@ -29,21 +28,43 @@ def unnest_subqueries(expression):
for scope in traverse_scope(expression):
select = scope.expression
parent = select.parent_select
+ if not parent:
+ continue
if scope.external_columns:
decorrelate(select, parent, scope.external_columns, sequence)
- else:
+ elif scope.scope_type == ScopeType.SUBQUERY:
unnest(select, parent, sequence)
return expression
def unnest(select, parent_select, sequence):
- predicate = select.find_ancestor(exp.In, exp.Any)
+ if len(select.selects) > 1:
+ return
+
+ predicate = select.find_ancestor(exp.Condition)
+ alias = _alias(sequence)
if not predicate or parent_select is not predicate.parent_select:
return
- if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
+ # this subquery returns a scalar and can just be converted to a cross join
+ if not isinstance(predicate, (exp.In, exp.Any)):
+ having = predicate.find_ancestor(exp.Having)
+ column = exp.column(select.selects[0].alias_or_name, alias)
+ if having and having.parent_select is parent_select:
+ column = exp.Max(this=column)
+ _replace(select.parent, column)
+
+ parent_select.join(
+ select,
+ join_type="CROSS",
+ join_alias=alias,
+ copy=False,
+ )
+ return
+
+ if select.find(exp.Limit, exp.Offset):
return
if isinstance(predicate, exp.Any):
@@ -54,7 +75,6 @@ def unnest(select, parent_select, sequence):
column = _other_operand(predicate)
value = select.selects[0]
- alias = _alias(sequence)
on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
_replace(predicate, f"NOT {on.right} IS NULL")
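
Scalar subqueries (anything that is not an IN or = ANY predicate) are now lifted into cross joins, with MAX() applied when the reference sits under a HAVING. A sketch; in the full optimizer this runs after column qualification, so the projection is aliased by hand here:

    import sqlglot
    from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

    sql = "SELECT * FROM x WHERE x.a > (SELECT SUM(y.b) AS b FROM y)"
    print(unnest_subqueries(sqlglot.parse_one(sql)).sql())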
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 5b93510..bdf0d2d 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -4,7 +4,7 @@ import logging
import typing as t
from sqlglot import exp
-from sqlglot.errors import ErrorLevel, ParseError, concat_errors
+from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
from sqlglot.helper import apply_index_offset, ensure_collection, seq_get
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import in_trie, new_trie
@@ -104,6 +104,7 @@ class Parser(metaclass=_Parser):
TokenType.BINARY,
TokenType.VARBINARY,
TokenType.JSON,
+ TokenType.JSONB,
TokenType.INTERVAL,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
@@ -115,6 +116,7 @@ class Parser(metaclass=_Parser):
TokenType.GEOGRAPHY,
TokenType.GEOMETRY,
TokenType.HLLSKETCH,
+ TokenType.HSTORE,
TokenType.SUPER,
TokenType.SERIAL,
TokenType.SMALLSERIAL,
@@ -153,6 +155,7 @@ class Parser(metaclass=_Parser):
TokenType.COLLATE,
TokenType.COMMAND,
TokenType.COMMIT,
+ TokenType.COMPOUND,
TokenType.CONSTRAINT,
TokenType.CURRENT_TIME,
TokenType.DEFAULT,
@@ -194,6 +197,7 @@ class Parser(metaclass=_Parser):
TokenType.RANGE,
TokenType.REFERENCES,
TokenType.RETURNS,
+ TokenType.ROW,
TokenType.ROWS,
TokenType.SCHEMA,
TokenType.SCHEMA_COMMENT,
@@ -213,6 +217,7 @@ class Parser(metaclass=_Parser):
TokenType.TRUE,
TokenType.UNBOUNDED,
TokenType.UNIQUE,
+ TokenType.UNLOGGED,
TokenType.UNPIVOT,
TokenType.PROPERTIES,
TokenType.PROCEDURE,
@@ -400,9 +405,17 @@ class Parser(metaclass=_Parser):
TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
TokenType.BEGIN: lambda self: self._parse_transaction(),
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
+ TokenType.END: lambda self: self._parse_commit_or_rollback(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
}
+ UNARY_PARSERS = {
+ TokenType.PLUS: lambda self: self._parse_unary(), # Unary + is handled as a no-op
+ TokenType.NOT: lambda self: self.expression(exp.Not, this=self._parse_equality()),
+ TokenType.TILDA: lambda self: self.expression(exp.BitwiseNot, this=self._parse_unary()),
+ TokenType.DASH: lambda self: self.expression(exp.Neg, this=self._parse_unary()),
+ }
+
PRIMARY_PARSERS = {
TokenType.STRING: lambda self, token: self.expression(
exp.Literal, this=token.text, is_string=True
@@ -446,19 +459,20 @@ class Parser(metaclass=_Parser):
}
PROPERTY_PARSERS = {
- TokenType.AUTO_INCREMENT: lambda self: self._parse_auto_increment(),
- TokenType.CHARACTER_SET: lambda self: self._parse_character_set(),
- TokenType.LOCATION: lambda self: self.expression(
- exp.LocationProperty,
- this=exp.Literal.string("LOCATION"),
- value=self._parse_string(),
+ TokenType.AUTO_INCREMENT: lambda self: self._parse_property_assignment(
+ exp.AutoIncrementProperty
),
+ TokenType.CHARACTER_SET: lambda self: self._parse_character_set(),
+ TokenType.LOCATION: lambda self: self._parse_property_assignment(exp.LocationProperty),
TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(),
- TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(),
- TokenType.STORED: lambda self: self._parse_stored(),
+ TokenType.SCHEMA_COMMENT: lambda self: self._parse_property_assignment(
+ exp.SchemaCommentProperty
+ ),
+ TokenType.STORED: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
TokenType.DISTKEY: lambda self: self._parse_distkey(),
- TokenType.DISTSTYLE: lambda self: self._parse_diststyle(),
+ TokenType.DISTSTYLE: lambda self: self._parse_property_assignment(exp.DistStyleProperty),
TokenType.SORTKEY: lambda self: self._parse_sortkey(),
+ TokenType.LIKE: lambda self: self._parse_create_like(),
TokenType.RETURNS: lambda self: self._parse_returns(),
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
@@ -468,7 +482,7 @@ class Parser(metaclass=_Parser):
),
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
- TokenType.EXECUTE: lambda self: self._parse_execute_as(),
+ TokenType.EXECUTE: lambda self: self._parse_property_assignment(exp.ExecuteAsProperty),
TokenType.DETERMINISTIC: lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
),
@@ -489,6 +503,7 @@ class Parser(metaclass=_Parser):
),
TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(),
TokenType.UNIQUE: lambda self: self._parse_unique(),
+ TokenType.LIKE: lambda self: self._parse_create_like(),
}
NO_PAREN_FUNCTION_PARSERS = {
@@ -505,6 +520,7 @@ class Parser(metaclass=_Parser):
"TRIM": lambda self: self._parse_trim(),
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"TRY_CAST": lambda self: self._parse_cast(False),
+ "STRING_AGG": lambda self: self._parse_string_agg(),
}
QUERY_MODIFIER_PARSERS = {
@@ -556,7 +572,7 @@ class Parser(metaclass=_Parser):
"_curr",
"_next",
"_prev",
- "_prev_comment",
+ "_prev_comments",
"_show_trie",
"_set_trie",
)
@@ -589,7 +605,7 @@ class Parser(metaclass=_Parser):
self._curr = None
self._next = None
self._prev = None
- self._prev_comment = None
+ self._prev_comments = None
def parse(self, raw_tokens, sql=None):
"""
@@ -608,6 +624,7 @@ class Parser(metaclass=_Parser):
)
def parse_into(self, expression_types, raw_tokens, sql=None):
+ errors = []
for expression_type in ensure_collection(expression_types):
parser = self.EXPRESSION_PARSERS.get(expression_type)
if not parser:
@@ -615,8 +632,12 @@ class Parser(metaclass=_Parser):
try:
return self._parse(parser, raw_tokens, sql)
except ParseError as e:
- error = e
- raise ParseError(f"Failed to parse into {expression_types}") from error
+ e.errors[0]["into_expression"] = expression_type
+ errors.append(e)
+ raise ParseError(
+ f"Failed to parse into {expression_types}",
+ errors=merge_errors(errors),
+ ) from errors[-1]
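
Since every failed attempt is now recorded, callers can inspect a structured error list instead of only the formatted message. A minimal sketch, assuming the merged `errors` attribute and the key set passed to `ParseError.new` further down:

    import sqlglot
    from sqlglot.errors import ParseError

    try:
        sqlglot.parse_one("SELECT 1 +")  # deliberately malformed
    except ParseError as e:
        first = e.errors[0]  # one dict per recorded error
        print(first["description"], first["line"], first["col"])
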
def _parse(self, parse_method, raw_tokens, sql=None):
self.reset()
@@ -650,7 +671,10 @@ class Parser(metaclass=_Parser):
for error in self.errors:
logger.error(str(error))
elif self.error_level == ErrorLevel.RAISE and self.errors:
- raise ParseError(concat_errors(self.errors, self.max_errors))
+ raise ParseError(
+ concat_messages(self.errors, self.max_errors),
+ errors=merge_errors(self.errors),
+ )
def raise_error(self, message, token=None):
token = token or self._curr or self._prev or Token.string("")
@@ -659,19 +683,27 @@ class Parser(metaclass=_Parser):
start_context = self.sql[max(start - self.error_message_context, 0) : start]
highlight = self.sql[start:end]
end_context = self.sql[end : end + self.error_message_context]
- error = ParseError(
+ error = ParseError.new(
f"{message}. Line {token.line}, Col: {token.col}.\n"
- f" {start_context}\033[4m{highlight}\033[0m{end_context}"
+ f" {start_context}\033[4m{highlight}\033[0m{end_context}",
+ description=message,
+ line=token.line,
+ col=token.col,
+ start_context=start_context,
+ highlight=highlight,
+ end_context=end_context,
)
if self.error_level == ErrorLevel.IMMEDIATE:
raise error
self.errors.append(error)
- def expression(self, exp_class, **kwargs):
+ def expression(self, exp_class, comments=None, **kwargs):
instance = exp_class(**kwargs)
- if self._prev_comment:
- instance.comment = self._prev_comment
- self._prev_comment = None
+ if self._prev_comments:
+ instance.comments = self._prev_comments
+ self._prev_comments = None
+ if comments:
+ instance.comments = comments
self.validate_expression(instance)
return instance
@@ -714,10 +746,10 @@ class Parser(metaclass=_Parser):
self._next = seq_get(self._tokens, self._index + 1)
if self._index > 0:
self._prev = self._tokens[self._index - 1]
- self._prev_comment = self._prev.comment
+ self._prev_comments = self._prev.comments
else:
self._prev = None
- self._prev_comment = None
+ self._prev_comments = None
def _retreat(self, index):
self._advance(index - self._index)
@@ -768,7 +800,7 @@ class Parser(metaclass=_Parser):
)
def _parse_create(self):
- replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
+ replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
temporary = self._match(TokenType.TEMPORARY)
transient = self._match(TokenType.TRANSIENT)
unique = self._match(TokenType.UNIQUE)
@@ -822,97 +854,57 @@ class Parser(metaclass=_Parser):
def _parse_property(self):
if self._match_set(self.PROPERTY_PARSERS):
return self.PROPERTY_PARSERS[self._prev.token_type](self)
+
if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
return self._parse_character_set(True)
+ if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
+ return self._parse_sortkey(compound=True)
+
if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False):
- key = self._parse_var().this
+ key = self._parse_var()
self._match(TokenType.EQ)
-
- return self.expression(
- exp.AnonymousProperty,
- this=exp.Literal.string(key),
- value=self._parse_column(),
- )
+ return self.expression(exp.Property, this=key, value=self._parse_column())
return None
def _parse_property_assignment(self, exp_class):
- prop = self._prev.text
self._match(TokenType.EQ)
- return self.expression(exp_class, this=prop, value=self._parse_var_or_string())
+ self._match(TokenType.ALIAS)
+ return self.expression(exp_class, this=self._parse_var_or_string() or self._parse_number())
def _parse_partitioned_by(self):
self._match(TokenType.EQ)
return self.expression(
exp.PartitionedByProperty,
- this=exp.Literal.string("PARTITIONED_BY"),
- value=self._parse_schema() or self._parse_bracket(self._parse_field()),
- )
-
- def _parse_stored(self):
- self._match(TokenType.ALIAS)
- self._match(TokenType.EQ)
- return self.expression(
- exp.FileFormatProperty,
- this=exp.Literal.string("FORMAT"),
- value=exp.Literal.string(self._parse_var_or_string().name),
+ this=self._parse_schema() or self._parse_bracket(self._parse_field()),
)
def _parse_distkey(self):
- self._match_l_paren()
- this = exp.Literal.string("DISTKEY")
- value = exp.Literal.string(self._parse_var().name)
- self._match_r_paren()
- return self.expression(
- exp.DistKeyProperty,
- this=this,
- value=value,
- )
+ return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_var))
- def _parse_sortkey(self):
- self._match_l_paren()
- this = exp.Literal.string("SORTKEY")
- value = exp.Literal.string(self._parse_var().name)
- self._match_r_paren()
- return self.expression(
- exp.SortKeyProperty,
- this=this,
- value=value,
- )
-
- def _parse_diststyle(self):
- this = exp.Literal.string("DISTSTYLE")
- value = exp.Literal.string(self._parse_var().name)
- return self.expression(
- exp.DistStyleProperty,
- this=this,
- value=value,
- )
-
- def _parse_auto_increment(self):
- self._match(TokenType.EQ)
- return self.expression(
- exp.AutoIncrementProperty,
- this=exp.Literal.string("AUTO_INCREMENT"),
- value=self._parse_number(),
- )
+ def _parse_create_like(self):
+ table = self._parse_table(schema=True)
+ options = []
+ while self._match_texts(("INCLUDING", "EXCLUDING")):
+ options.append(
+ self.expression(
+ exp.Property,
+ this=self._prev.text.upper(),
+ value=exp.Var(this=self._parse_id_var().this.upper()),
+ )
+ )
+ return self.expression(exp.LikeProperty, this=table, expressions=options)
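
The new LIKE property covers Postgres-style table cloning clauses. A hedged sketch of the intended round trip (full support also depends on the generator changes in this release):

    import sqlglot

    sql = "CREATE TABLE new_t (LIKE src_t INCLUDING DEFAULTS EXCLUDING CONSTRAINTS)"
    print(sqlglot.parse_one(sql, read="postgres").sql(dialect="postgres"))
    # expected to round-trip the LIKE clause with its INCLUDING/EXCLUDING options
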
- def _parse_schema_comment(self):
- self._match(TokenType.EQ)
+ def _parse_sortkey(self, compound=False):
return self.expression(
- exp.SchemaCommentProperty,
- this=exp.Literal.string("COMMENT"),
- value=self._parse_string(),
+ exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_var), compound=compound
)
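
Threading `compound` through keeps Redshift's two sort-key flavors distinguishable; the `_match_pair(TokenType.COMPOUND, TokenType.SORTKEY)` branch above routes the compound form here. For instance (a sketch):

    import sqlglot

    expr = sqlglot.parse_one("CREATE TABLE t (a INT, b INT) COMPOUND SORTKEY(a, b)", read="redshift")
    # the SortKeyProperty node should carry compound=True and both columns
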
def _parse_character_set(self, default=False):
self._match(TokenType.EQ)
return self.expression(
- exp.CharacterSetProperty,
- this=exp.Literal.string("CHARACTER_SET"),
- value=self._parse_var_or_string(),
- default=default,
+ exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default
)
def _parse_returns(self):
@@ -931,20 +923,7 @@ class Parser(metaclass=_Parser):
else:
value = self._parse_types()
- return self.expression(
- exp.ReturnsProperty,
- this=exp.Literal.string("RETURNS"),
- value=value,
- is_table=is_table,
- )
-
- def _parse_execute_as(self):
- self._match(TokenType.ALIAS)
- return self.expression(
- exp.ExecuteAsProperty,
- this=exp.Literal.string("EXECUTE AS"),
- value=self._parse_var(),
- )
+ return self.expression(exp.ReturnsProperty, this=value, is_table=is_table)
def _parse_properties(self):
properties = []
@@ -956,7 +935,7 @@ class Parser(metaclass=_Parser):
properties.extend(
self._parse_wrapped_csv(
lambda: self.expression(
- exp.AnonymousProperty,
+ exp.Property,
this=self._parse_string(),
value=self._match(TokenType.EQ) and self._parse_string(),
)
@@ -1076,7 +1055,12 @@ class Parser(metaclass=_Parser):
options = []
if self._match(TokenType.OPTIONS):
- options = self._parse_wrapped_csv(self._parse_string, sep=TokenType.EQ)
+ self._match_l_paren()
+ k = self._parse_string()
+ self._match(TokenType.EQ)
+ v = self._parse_string()
+ options = [k, v]
+ self._match_r_paren()
self._match(TokenType.ALIAS)
return self.expression(
@@ -1116,7 +1100,7 @@ class Parser(metaclass=_Parser):
self.raise_error(f"{this.key} does not support CTE")
this = cte
elif self._match(TokenType.SELECT):
- comment = self._prev_comment
+ comments = self._prev_comments
hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
@@ -1141,10 +1125,16 @@ class Parser(metaclass=_Parser):
expressions=expressions,
limit=limit,
)
- this.comment = comment
+ this.comments = comments
+
+ into = self._parse_into()
+ if into:
+ this.set("into", into)
+
from_ = self._parse_from()
if from_:
this.set("from", from_)
+
self._parse_query_modifiers(this)
elif (table or nested) and self._match(TokenType.L_PAREN):
this = self._parse_table() if table else self._parse_select(nested=True)
@@ -1248,11 +1238,24 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Hint, expressions=hints)
return None
+ def _parse_into(self):
+ if not self._match(TokenType.INTO):
+ return None
+
+ temp = self._match(TokenType.TEMPORARY)
+ unlogged = self._match(TokenType.UNLOGGED)
+ self._match(TokenType.TABLE)
+
+ return self.expression(
+ exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged
+ )
+
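
SELECT INTO is now handled in the base parser. A quick illustration of the resulting tree (arg names follow the expression built above):

    import sqlglot

    expr = sqlglot.parse_one("SELECT * INTO TEMPORARY TABLE t2 FROM t1")
    into = expr.args["into"]  # an exp.Into node
    print(into.args.get("temporary"), into.args.get("unlogged"))  # True, False
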
def _parse_from(self):
if not self._match(TokenType.FROM):
return None
-
- return self.expression(exp.From, expressions=self._parse_csv(self._parse_table))
+ return self.expression(
+ exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table)
+ )
def _parse_lateral(self):
outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY)
@@ -1515,7 +1518,9 @@ class Parser(metaclass=_Parser):
def _parse_where(self, skip_where_token=False):
if not skip_where_token and not self._match(TokenType.WHERE):
return None
- return self.expression(exp.Where, this=self._parse_conjunction())
+ return self.expression(
+ exp.Where, comments=self._prev_comments, this=self._parse_conjunction()
+ )
def _parse_group(self, skip_group_by_token=False):
if not skip_group_by_token and not self._match(TokenType.GROUP_BY):
@@ -1737,12 +1742,8 @@ class Parser(metaclass=_Parser):
return self._parse_tokens(self._parse_unary, self.FACTOR)
def _parse_unary(self):
- if self._match(TokenType.NOT):
- return self.expression(exp.Not, this=self._parse_equality())
- if self._match(TokenType.TILDA):
- return self.expression(exp.BitwiseNot, this=self._parse_unary())
- if self._match(TokenType.DASH):
- return self.expression(exp.Neg, this=self._parse_unary())
+ if self._match_set(self.UNARY_PARSERS):
+ return self.UNARY_PARSERS[self._prev.token_type](self)
return self._parse_at_time_zone(self._parse_type())
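
Moving the unary operators into a class-level table means dialects can override them without re-implementing _parse_unary. A hypothetical dialect tweak, for illustration only:

    from sqlglot import exp
    from sqlglot.parser import Parser
    from sqlglot.tokens import TokenType

    class MyParser(Parser):
        # illustrative only: treat ~ as logical NOT instead of bitwise NOT
        UNARY_PARSERS = {
            **Parser.UNARY_PARSERS,
            TokenType.TILDA: lambda self: self.expression(exp.Not, this=self._parse_unary()),
        }
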
def _parse_type(self):
@@ -1775,17 +1776,6 @@ class Parser(metaclass=_Parser):
expressions = None
maybe_func = False
- if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
- return exp.DataType(
- this=exp.DataType.Type.ARRAY,
- expressions=[exp.DataType.build(type_token.value)],
- nested=True,
- )
-
- if self._match(TokenType.L_BRACKET):
- self._retreat(index)
- return None
-
if self._match(TokenType.L_PAREN):
if is_struct:
expressions = self._parse_csv(self._parse_struct_kwargs)
@@ -1801,6 +1791,17 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
maybe_func = True
+ if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
+ return exp.DataType(
+ this=exp.DataType.Type.ARRAY,
+ expressions=[exp.DataType.build(type_token.value, expressions=expressions)],
+ nested=True,
+ )
+
+ if self._match(TokenType.L_BRACKET):
+ self._retreat(index)
+ return None
+
if nested and self._match(TokenType.LT):
if is_struct:
expressions = self._parse_csv(self._parse_struct_kwargs)
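
Checking for brackets only after the parenthesized arguments means parameterized element types survive, since the element is now built with `expressions=expressions`. A sketch, assuming the reading dialect accepts bracket array syntax:

    import sqlglot

    expr = sqlglot.parse_one("CAST(x AS VARCHAR(10)[])", read="postgres")
    print(expr.sql(dialect="postgres"))  # expected: CAST(x AS VARCHAR(10)[])
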
@@ -1904,7 +1905,7 @@ class Parser(metaclass=_Parser):
return exp.Literal.number(f"0.{self._prev.text}")
if self._match(TokenType.L_PAREN):
- comment = self._prev_comment
+ comments = self._prev_comments
query = self._parse_select()
if query:
@@ -1924,8 +1925,8 @@ class Parser(metaclass=_Parser):
this = self.expression(exp.Tuple, expressions=expressions)
else:
this = self.expression(exp.Paren, this=this)
- if comment:
- this.comment = comment
+ if comments:
+ this.comments = comments
return this
return None
@@ -2098,7 +2099,10 @@ class Parser(metaclass=_Parser):
elif self._match(TokenType.SCHEMA_COMMENT):
kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
elif self._match(TokenType.PRIMARY_KEY):
- kind = exp.PrimaryKeyColumnConstraint()
+ desc = None
+ if self._match(TokenType.ASC) or self._match(TokenType.DESC):
+ desc = self._prev.token_type == TokenType.DESC
+ kind = exp.PrimaryKeyColumnConstraint(desc=desc)
elif self._match(TokenType.UNIQUE):
kind = exp.UniqueColumnConstraint()
elif self._match(TokenType.GENERATED):
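
PRIMARY KEY now records an optional sort direction, matching SQLite's column syntax. For example (hedged):

    import sqlglot

    expr = sqlglot.parse_one("CREATE TABLE t (id INTEGER PRIMARY KEY DESC)", read="sqlite")
    # the PrimaryKeyColumnConstraint should carry desc=True
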
@@ -2189,7 +2193,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.R_BRACKET):
self.raise_error("Expected ]")
- this.comment = self._prev_comment
+ this.comments = self._prev_comments
return self._parse_bracket(this)
def _parse_case(self):
@@ -2256,6 +2260,33 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+ def _parse_string_agg(self):
+ if self._match(TokenType.DISTINCT):
+ args = self._parse_csv(self._parse_conjunction)
+ expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)])
+ else:
+ args = self._parse_csv(self._parse_conjunction)
+ expression = seq_get(args, 0)
+
+ index = self._index
+ if not self._match(TokenType.R_PAREN):
+ # postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]])
+ order = self._parse_order(this=expression)
+ return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
+
+ # Check whether an order clause follows: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]).
+ # This is done "manually" instead of letting _parse_window wrap it in an exp.WithinGroup node, so that
+ # the STRING_AGG call parses the same way as MySQL / SQLite's GROUP_CONCAT and transpiles to them easily.
+ if not self._match(TokenType.WITHIN_GROUP):
+ self._retreat(index)
+ this = exp.GroupConcat.from_arg_list(args)
+ self.validate_expression(this, args)
+ return this
+
+ self._match_l_paren() # The corresponding _match_r_paren will be called by the caller, _parse_function
+ order = self._parse_order(this=expression)
+ return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
+
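
Parsing STRING_AGG into exp.GroupConcat is what makes this aggregation transpile across dialects. A sketch of the expected behavior:

    import sqlglot

    print(sqlglot.transpile(
        "SELECT STRING_AGG(x, ',' ORDER BY y) FROM t",
        read="postgres",
        write="mysql",
    )[0])
    # expected, roughly: SELECT GROUP_CONCAT(x ORDER BY y SEPARATOR ',') FROM t
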
def _parse_convert(self, strict):
this = self._parse_column()
if self._match(TokenType.USING):
@@ -2511,8 +2542,8 @@ class Parser(metaclass=_Parser):
items = [parse_result] if parse_result is not None else []
while self._match(sep):
- if parse_result and self._prev_comment is not None:
- parse_result.comment = self._prev_comment
+ if parse_result and self._prev_comments:
+ parse_result.comments = self._prev_comments
parse_result = parse_method()
if parse_result is not None:
@@ -2525,7 +2556,10 @@ class Parser(metaclass=_Parser):
while self._match_set(expressions):
this = self.expression(
- expressions[self._prev.token_type], this=this, expression=parse_method()
+ expressions[self._prev.token_type],
+ this=this,
+ comments=self._prev_comments,
+ expression=parse_method(),
)
return this
@@ -2566,6 +2600,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Transaction, this=this, modes=modes)
def _parse_commit_or_rollback(self):
+ chain = None
savepoint = None
is_rollback = self._prev.token_type == TokenType.ROLLBACK
@@ -2575,9 +2610,13 @@ class Parser(metaclass=_Parser):
self._match_text_seq("SAVEPOINT")
savepoint = self._parse_id_var()
+ if self._match(TokenType.AND):
+ chain = not self._match_text_seq("NO")
+ self._match_text_seq("CHAIN")
+
if is_rollback:
return self.expression(exp.Rollback, savepoint=savepoint)
- return self.expression(exp.Commit)
+ return self.expression(exp.Commit, chain=chain)
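
COMMIT now understands the standard AND [NO] CHAIN suffix; `chain` stays None when the suffix is absent, so generators can omit it. For instance:

    import sqlglot

    print(sqlglot.parse_one("COMMIT AND NO CHAIN").args.get("chain"))  # False
    print(sqlglot.parse_one("COMMIT AND CHAIN").args.get("chain"))     # True
    print(sqlglot.parse_one("COMMIT").args.get("chain"))               # None
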
def _parse_show(self):
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
@@ -2651,14 +2690,14 @@ class Parser(metaclass=_Parser):
def _match_l_paren(self, expression=None):
if not self._match(TokenType.L_PAREN):
self.raise_error("Expecting (")
- if expression and self._prev_comment:
- expression.comment = self._prev_comment
+ if expression and self._prev_comments:
+ expression.comments = self._prev_comments
def _match_r_paren(self, expression=None):
if not self._match(TokenType.R_PAREN):
self.raise_error("Expecting )")
- if expression and self._prev_comment:
- expression.comment = self._prev_comment
+ if expression and self._prev_comments:
+ expression.comments = self._prev_comments
def _match_texts(self, texts):
if self._curr and self._curr.text.upper() in texts:
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index 51db2d4..4967231 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -130,18 +130,20 @@ class Step:
aggregations = []
sequence = itertools.count()
- for e in expression.expressions:
- aggregation = e.find(exp.AggFunc)
-
- if aggregation:
- projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
- aggregations.append(e)
- for operand in aggregation.unnest_operands():
+ def extract_agg_operands(expression):
+ for agg in expression.find_all(exp.AggFunc):
+ for operand in agg.unnest_operands():
if isinstance(operand, exp.Column):
continue
if operand not in operands:
operands[operand] = f"_a_{next(sequence)}"
operand.replace(exp.column(operands[operand], quoted=True))
+
+ for e in expression.expressions:
+ if e.find(exp.AggFunc):
+ projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
+ aggregations.append(e)
+ extract_agg_operands(e)
else:
projections.append(e)
@@ -156,6 +158,13 @@ class Step:
aggregate = Aggregate()
aggregate.source = step.name
aggregate.name = step.name
+
+ having = expression.args.get("having")
+
+ if having:
+ extract_agg_operands(having)
+ aggregate.condition = having.this
+
aggregate.operands = tuple(
alias(operand, alias_) for operand, alias_ in operands.items()
)
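
Extracting HAVING operands during aggregation planning lets the aggregate step own its condition instead of patching it on afterwards. A rough sketch of the effect (hedged; the planner is normally fed optimized, qualified SQL):

    from sqlglot import parse_one
    from sqlglot.planner import Plan

    plan = Plan(parse_one("SELECT a FROM t GROUP BY a HAVING SUM(b) > 10"))
    # the Aggregate step should now carry the HAVING predicate as its condition
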
@@ -172,11 +181,6 @@ class Step:
aggregate.add_dependency(step)
step = aggregate
- having = expression.args.get("having")
-
- if having:
- step.condition = having.this
-
order = expression.args.get("order")
if order:
@@ -188,6 +192,17 @@ class Step:
step.projections = projections
+ if isinstance(expression, exp.Select) and expression.args.get("distinct"):
+ distinct = Aggregate()
+ distinct.source = step.name
+ distinct.name = step.name
+ distinct.group = {
+ e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
+ for e in projections or expression.expressions
+ }
+ distinct.add_dependency(step)
+ step = distinct
+
limit = expression.args.get("limit")
if limit:
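
DISTINCT is planned as an extra Aggregate step that groups by every projection, i.e. SELECT DISTINCT a, b FROM t executes like SELECT a, b FROM t GROUP BY a, b. A hedged sketch:

    from sqlglot import parse_one
    from sqlglot.planner import Plan

    plan = Plan(parse_one("SELECT DISTINCT a, b FROM t"))
    print(plan.root)  # an Aggregate step grouping on a and b sits above the scan
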
@@ -231,6 +246,9 @@ class Step:
if self.condition:
lines.append(f"{nested}Condition: {self.condition.sql()}")
+ if self.limit is not math.inf:
+ lines.append(f"{nested}Limit: {self.limit}")
+
if self.dependencies:
lines.append(f"{nested}Dependencies:")
for dependency in self.dependencies:
@@ -258,12 +276,7 @@ class Scan(Step):
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step:
table = expression
- alias_ = expression.alias
-
- if not alias_:
- raise UnsupportedError(
- "Tables/Subqueries must be aliased. Run it through the optimizer"
- )
+ alias_ = expression.alias_or_name
if isinstance(expression, exp.Subquery):
table = expression.this
@@ -338,6 +351,9 @@ class Aggregate(Step):
lines.append(f"{indent}Group:")
for expression in self.group.values():
lines.append(f"{indent} - {expression.sql()}")
+ if self.condition:
+ lines.append(f"{indent}Having:")
+ lines.append(f"{indent} - {self.condition.sql()}")
if self.operands:
lines.append(f"{indent}Operands:")
for expression in self.operands:
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index ec8cd91..8a7a38e 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -81,6 +81,7 @@ class TokenType(AutoName):
BINARY = auto()
VARBINARY = auto()
JSON = auto()
+ JSONB = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
@@ -91,6 +92,7 @@ class TokenType(AutoName):
NULLABLE = auto()
GEOMETRY = auto()
HLLSKETCH = auto()
+ HSTORE = auto()
SUPER = auto()
SERIAL = auto()
SMALLSERIAL = auto()
@@ -113,6 +115,7 @@ class TokenType(AutoName):
APPLY = auto()
ARRAY = auto()
ASC = auto()
+ ASOF = auto()
AT_TIME_ZONE = auto()
AUTO_INCREMENT = auto()
BEGIN = auto()
@@ -130,6 +133,7 @@ class TokenType(AutoName):
COMMAND = auto()
COMMENT = auto()
COMMIT = auto()
+ COMPOUND = auto()
CONSTRAINT = auto()
CREATE = auto()
CROSS = auto()
@@ -271,6 +275,7 @@ class TokenType(AutoName):
UNBOUNDED = auto()
UNCACHE = auto()
UNION = auto()
+ UNLOGGED = auto()
UNNEST = auto()
UNPIVOT = auto()
UPDATE = auto()
@@ -291,7 +296,7 @@ class TokenType(AutoName):
class Token:
- __slots__ = ("token_type", "text", "line", "col", "comment")
+ __slots__ = ("token_type", "text", "line", "col", "comments")
@classmethod
def number(cls, number: int) -> Token:
@@ -319,13 +324,13 @@ class Token:
text: str,
line: int = 1,
col: int = 1,
- comment: t.Optional[str] = None,
+ comments: t.Optional[t.List[str]] = None,  # a mutable default list would be shared across tokens
) -> None:
self.token_type = token_type
self.text = text
self.line = line
self.col = max(col - len(text), 1)
- self.comment = comment
+ self.comments = comments if comments is not None else []
def __repr__(self) -> str:
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
@@ -452,6 +457,7 @@ class Tokenizer(metaclass=_Tokenizer):
"COLLATE": TokenType.COLLATE,
"COMMENT": TokenType.SCHEMA_COMMENT,
"COMMIT": TokenType.COMMIT,
+ "COMPOUND": TokenType.COMPOUND,
"CONSTRAINT": TokenType.CONSTRAINT,
"CREATE": TokenType.CREATE,
"CROSS": TokenType.CROSS,
@@ -582,8 +588,9 @@ class Tokenizer(metaclass=_Tokenizer):
"TRAILING": TokenType.TRAILING,
"UNBOUNDED": TokenType.UNBOUNDED,
"UNION": TokenType.UNION,
- "UNPIVOT": TokenType.UNPIVOT,
+ "UNLOGGED": TokenType.UNLOGGED,
"UNNEST": TokenType.UNNEST,
+ "UNPIVOT": TokenType.UNPIVOT,
"UPDATE": TokenType.UPDATE,
"USE": TokenType.USE,
"USING": TokenType.USING,
@@ -686,12 +693,12 @@ class Tokenizer(metaclass=_Tokenizer):
"_current",
"_line",
"_col",
- "_comment",
+ "_comments",
"_char",
"_end",
"_peek",
"_prev_token_line",
- "_prev_token_comment",
+ "_prev_token_comments",
"_prev_token_type",
"_replace_backslash",
)
@@ -708,13 +715,13 @@ class Tokenizer(metaclass=_Tokenizer):
self._current = 0
self._line = 1
self._col = 1
- self._comment = None
+ self._comments: t.List[str] = []
self._char = None
self._end = None
self._peek = None
self._prev_token_line = -1
- self._prev_token_comment = None
+ self._prev_token_comments: t.List[str] = []
self._prev_token_type = None
def tokenize(self, sql: str) -> t.List[Token]:
@@ -767,7 +774,7 @@ class Tokenizer(metaclass=_Tokenizer):
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line
- self._prev_token_comment = self._comment
+ self._prev_token_comments = self._comments
self._prev_token_type = token_type # type: ignore
self.tokens.append(
Token(
@@ -775,10 +782,10 @@ class Tokenizer(metaclass=_Tokenizer):
self._text if text is None else text,
self._line,
self._col,
- self._comment,
+ self._comments,
)
)
- self._comment = None
+ self._comments = []
if token_type in self.COMMANDS and (
len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
@@ -857,22 +864,18 @@ class Tokenizer(metaclass=_Tokenizer):
while not self._end and self._chars(comment_end_size) != comment_end:
self._advance()
- self._comment = self._text[comment_start_size : -comment_end_size + 1] # type: ignore
+ self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore
self._advance(comment_end_size - 1)
else:
while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore
self._advance()
- self._comment = self._text[comment_start_size:] # type: ignore
-
- # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both
- # types of comment can be attached to a token, the trailing one is discarded in favour of the leading one.
+ self._comments.append(self._text[comment_start_size:]) # type: ignore
+ # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
+ # Multiple consecutive comments are preserved by appending them to the current comments list.
if comment_start_line == self._prev_token_line:
- if self._prev_token_comment is None:
- self.tokens[-1].comment = self._comment
- self._prev_token_comment = self._comment
-
- self._comment = None
+ self.tokens[-1].comments.extend(self._comments)
+ self._comments = []
return True
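
With comments accumulated into a list, consecutive comments are no longer dropped. For example (hedged; this also relies on the expressions.py rename from comment to comments in this release):

    import sqlglot

    expr = sqlglot.parse_one("/* a */ /* b */ SELECT 1")
    print(expr.comments)  # expected: [' a ', ' b ']
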
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 412b881..99949a1 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -2,6 +2,8 @@ from __future__ import annotations
import typing as t
+from sqlglot.helper import find_new_name
+
if t.TYPE_CHECKING:
from sqlglot.generator import Generator
@@ -43,6 +45,43 @@ def unalias_group(expression: exp.Expression) -> exp.Expression:
return expression
+def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
+ """
+ Convert SELECT DISTINCT ON statements to a subquery with a window function.
+
+ This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
+
+ Args:
+ expression: the expression that will be transformed.
+
+ Returns:
+ The transformed expression.
+ """
+ if (
+ isinstance(expression, exp.Select)
+ and expression.args.get("distinct")
+ and expression.args["distinct"].args.get("on")
+ and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
+ ):
+ distinct_cols = [e.copy() for e in expression.args["distinct"].args["on"].expressions]
+ outer_selects = [e.copy() for e in expression.expressions]
+ nested = expression.copy()
+ nested.args["distinct"].pop()
+ row_number = find_new_name(expression.named_selects, "_row_number")
+ window = exp.Window(
+ this=exp.RowNumber(),
+ partition_by=distinct_cols,
+ )
+ order = nested.args.get("order")
+ if order:
+ window.set("order", order.copy())
+ order.pop()
+ window = exp.alias_(window, row_number)
+ nested.select(window, copy=False)
+ return exp.select(*outer_selects).from_(nested.subquery()).where(f'"{row_number}" = 1')
+ return expression
+
+
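A usage sketch of the new transform (the output shown is approximate):

    from sqlglot import parse_one
    from sqlglot.transforms import eliminate_distinct_on

    sql = (
        parse_one("SELECT DISTINCT ON (a) a, b FROM t ORDER BY a, b DESC")
        .transform(eliminate_distinct_on)
        .sql()
    )
    # roughly: SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a
    #   ORDER BY a, b DESC) AS _row_number FROM t) WHERE "_row_number" = 1
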
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
to_sql: t.Callable[[Generator, exp.Expression], str],
@@ -81,3 +120,4 @@ def delegate(attr: str) -> t.Callable:
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
+ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}