diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-07-06 07:28:12 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-07-06 07:28:12 +0000 |
commit | 374a0f6318bcf423b1b784d30b25a8327c65cb24 (patch) | |
tree | 9303a1cbdba85b5d9781ebef32eb1902d3790c99 /sqlglot | |
parent | Releasing debian version 16.7.7-1. (diff) | |
download | sqlglot-374a0f6318bcf423b1b784d30b25a8327c65cb24.tar.xz sqlglot-374a0f6318bcf423b1b784d30b25a8327c65cb24.zip |
Merging upstream version 17.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
30 files changed, 396 insertions, 182 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 739ec29..42801ac 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -94,7 +94,11 @@ def parse_one(sql: str, **opts) -> Expression: def parse_one( - sql: str, read: DialectType = None, into: t.Optional[exp.IntoType] = None, **opts + sql: str, + read: DialectType = None, + dialect: DialectType = None, + into: t.Optional[exp.IntoType] = None, + **opts, ) -> Expression: """ Parses the given SQL string and returns a syntax tree for the first parsed SQL statement. @@ -102,6 +106,7 @@ def parse_one( Args: sql: the SQL code string to parse. read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). + dialect: the SQL dialect (alias for read) into: the SQLGlot Expression to parse into. **opts: other `sqlglot.parser.Parser` options. @@ -109,7 +114,7 @@ def parse_one( The syntax tree for the first parsed statement. """ - dialect = Dialect.get_or_raise(read)() + dialect = Dialect.get_or_raise(read or dialect)() if into: result = dialect.parse_into(into, sql, **opts) diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index c9d6c79..82162b4 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -163,6 +163,17 @@ def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression: return expression +def _parse_timestamp(args: t.List) -> exp.StrToTime: + this = format_time_lambda(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)]) + this.set("zone", seq_get(args, 2)) + return this + + +def _parse_date(args: t.List) -> exp.Date | exp.DateFromParts: + expr_type = exp.DateFromParts if len(args) == 3 else exp.Date + return expr_type.from_arg_list(args) + + class BigQuery(Dialect): UNNEST_COLUMN_ONLY = True @@ -203,8 +214,10 @@ class BigQuery(Dialect): while isinstance(parent, exp.Dot): parent = parent.parent - if not (isinstance(parent, exp.Table) and parent.db) and not expression.meta.get( - "is_table" + if ( + not isinstance(parent, exp.UserDefinedFunction) + and not (isinstance(parent, exp.Table) and parent.db) + and not expression.meta.get("is_table") ): expression.set("this", expression.this.lower()) @@ -251,6 +264,7 @@ class BigQuery(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, + "DATE": _parse_date, "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), "DATE_TRUNC": lambda args: exp.DateTrunc( @@ -264,9 +278,7 @@ class BigQuery(Dialect): "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")( [seq_get(args, 1), seq_get(args, 0)] ), - "PARSE_TIMESTAMP": lambda args: format_time_lambda(exp.StrToTime, "bigquery")( - [seq_get(args, 1), seq_get(args, 0)] - ), + "PARSE_TIMESTAMP": _parse_timestamp, "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list, "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( this=seq_get(args, 0), @@ -356,9 +368,11 @@ class BigQuery(Dialect): EXPLICIT_UNION = True INTERVAL_ALLOWS_PLURAL_FORM = False JOIN_HINTS = False + QUERY_HINTS = False TABLE_HINTS = False LIMIT_FETCH = "LIMIT" RENAME_TABLE_WITH_DB = False + ESCAPE_LINE_BREAK = True TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -367,6 +381,7 @@ class BigQuery(Dialect): exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]), exp.CTE: transforms.preprocess([_pushdown_cte_column_names]), exp.DateAdd: _date_add_sql("DATE", "ADD"), + exp.DateFromParts: rename_func("DATE"), exp.DateSub: _date_add_sql("DATE", "SUB"), exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"), exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"), @@ -397,7 +412,9 @@ class BigQuery(Dialect): ] ), exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})", - exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})", + exp.StrToTime: lambda self, e: self.func( + "PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone") + ), exp.TimeAdd: _date_add_sql("TIME", "ADD"), exp.TimeSub: _date_add_sql("TIME", "SUB"), exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), @@ -548,10 +565,15 @@ class BigQuery(Dialect): } def attimezone_sql(self, expression: exp.AtTimeZone) -> str: - if not isinstance(expression.parent, exp.Cast): + parent = expression.parent + + # BigQuery allows CAST(.. AS {STRING|TIMESTAMP} [FORMAT <fmt> [AT TIME ZONE <tz>]]). + # Only the TIMESTAMP one should use the below conversion, when AT TIME ZONE is included. + if not isinstance(parent, exp.Cast) or not parent.to.is_type("text"): return self.func( "TIMESTAMP", self.func("DATETIME", expression.this, expression.args.get("zone")) ) + return super().attimezone_sql(expression) def trycast_sql(self, expression: exp.TryCast) -> str: diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index efaf34c..9126c4b 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -109,10 +109,11 @@ class ClickHouse(Dialect): QUERY_MODIFIER_PARSERS = { **parser.Parser.QUERY_MODIFIER_PARSERS, - "settings": lambda self: self._parse_csv(self._parse_conjunction) - if self._match(TokenType.SETTINGS) - else None, - "format": lambda self: self._parse_id_var() if self._match(TokenType.FORMAT) else None, + TokenType.SETTINGS: lambda self: ( + "settings", + self._advance() or self._parse_csv(self._parse_conjunction), + ), + TokenType.FORMAT: lambda self: ("format", self._advance() or self._parse_id_var()), } def _parse_conjunction(self) -> t.Optional[exp.Expression]: @@ -155,9 +156,12 @@ class ClickHouse(Dialect): return this def _parse_table( - self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None + self, + schema: bool = False, + joins: bool = False, + alias_tokens: t.Optional[t.Collection[TokenType]] = None, ) -> t.Optional[exp.Expression]: - this = super()._parse_table(schema=schema, alias_tokens=alias_tokens) + this = super()._parse_table(schema=schema, joins=joins, alias_tokens=alias_tokens) if self._match(TokenType.FINAL): this = self.expression(exp.Final, this=this) @@ -273,6 +277,7 @@ class ClickHouse(Dialect): return None class Generator(generator.Generator): + QUERY_HINTS = False STRUCT_DELIMITER = ("(", ")") TYPE_MAPPING = { diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index d258826..4fc93bf 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -98,7 +98,6 @@ class _Dialect(type): klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) - klass.RAW_START, klass.RAW_END = get_start_end(TokenType.RAW_STRING) dialect_properties = { **{ diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 3cca986..26d09ce 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -96,6 +96,7 @@ class Drill(Dialect): class Generator(generator.Generator): JOIN_HINTS = False TABLE_HINTS = False + QUERY_HINTS = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 093a01c..d7e5a43 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import ( approx_count_distinct_sql, arrow_json_extract_scalar_sql, arrow_json_extract_sql, + date_trunc_to_time, datestrtodate_sql, format_time_lambda, no_comment_column_constraint_sql, @@ -38,6 +39,21 @@ def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.Dat return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}" +# BigQuery -> DuckDB conversion for the DATE function +def _date_sql(self: generator.Generator, expression: exp.Date) -> str: + result = f"CAST({self.sql(expression, 'this')} AS DATE)" + zone = self.sql(expression, "zone") + + if zone: + date_str = self.func("STRFTIME", result, "'%d/%m/%Y'") + date_str = f"{date_str} || ' ' || {zone}" + + # This will create a TIMESTAMP with time zone information + result = self.func("STRPTIME", date_str, "'%d/%m/%Y %Z'") + + return result + + def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str: if expression.expression: self.unsupported("DUCKDB ARRAY_SORT does not support a comparator") @@ -131,6 +147,8 @@ class DuckDB(Dialect): "ARRAY_REVERSE_SORT": _sort_array_reverse, "DATEDIFF": _parse_date_diff, "DATE_DIFF": _parse_date_diff, + "DATE_TRUNC": date_trunc_to_time, + "DATETRUNC": date_trunc_to_time, "EPOCH": exp.TimeToUnix.from_arg_list, "EPOCH_MS": lambda args: exp.UnixToTime( this=exp.Div(this=seq_get(args, 0), expression=exp.Literal.number(1000)) @@ -167,6 +185,7 @@ class DuckDB(Dialect): class Generator(generator.Generator): JOIN_HINTS = False TABLE_HINTS = False + QUERY_HINTS = False LIMIT_FETCH = "LIMIT" STRUCT_DELIMITER = ("(", ")") RENAME_TABLE_WITH_DB = False @@ -188,7 +207,9 @@ class DuckDB(Dialect): exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfYear: rename_func("DAYOFYEAR"), exp.DataType: _datatype_sql, + exp.Date: _date_sql, exp.DateAdd: _date_delta_sql, + exp.DateFromParts: rename_func("MAKE_DATE"), exp.DateSub: _date_delta_sql, exp.DateDiff: lambda self, e: self.func( "DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 6bca610..1abc0f4 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -273,13 +273,6 @@ class Hive(Dialect): ), } - QUERY_MODIFIER_PARSERS = { - **parser.Parser.QUERY_MODIFIER_PARSERS, - "cluster": lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY), - "distribute": lambda self: self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY), - "sort": lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY), - } - def _parse_types( self, check_func: bool = False, schema: bool = False ) -> t.Optional[exp.Expression]: @@ -319,6 +312,7 @@ class Hive(Dialect): TABLESAMPLE_SIZE_IS_PERCENT = True JOIN_HINTS = False TABLE_HINTS = False + QUERY_HINTS = False INDEX_ON = "ON TABLE" TYPE_MAPPING = { @@ -429,10 +423,3 @@ class Hive(Dialect): expression = exp.DataType.build(expression.this) return super().datatype_sql(expression) - - def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]: - return super().after_having_modifiers(expression) + [ - self.sql(expression, "distribute"), - self.sql(expression, "sort"), - self.sql(expression, "cluster"), - ] diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 5f743ee..bae0e50 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -123,14 +123,15 @@ class MySQL(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "CHARSET": TokenType.CHARACTER_SET, + "ENUM": TokenType.ENUM, "FORCE": TokenType.FORCE, "IGNORE": TokenType.IGNORE, "LONGBLOB": TokenType.LONGBLOB, "LONGTEXT": TokenType.LONGTEXT, "MEDIUMBLOB": TokenType.MEDIUMBLOB, "MEDIUMTEXT": TokenType.MEDIUMTEXT, + "MEMBER OF": TokenType.MEMBER_OF, "SEPARATOR": TokenType.SEPARATOR, - "ENUM": TokenType.ENUM, "START": TokenType.BEGIN, "SIGNED": TokenType.BIGINT, "SIGNED INTEGER": TokenType.BIGINT, @@ -185,11 +186,26 @@ class MySQL(Dialect): COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW} class Parser(parser.Parser): - FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} + FUNC_TOKENS = { + *parser.Parser.FUNC_TOKENS, + TokenType.DATABASE, + TokenType.SCHEMA, + TokenType.VALUES, + } + TABLE_ALIAS_TOKENS = ( parser.Parser.TABLE_ALIAS_TOKENS - parser.Parser.TABLE_INDEX_HINT_TOKENS ) + RANGE_PARSERS = { + **parser.Parser.RANGE_PARSERS, + TokenType.MEMBER_OF: lambda self, this: self.expression( + exp.JSONArrayContains, + this=this, + expression=self._parse_wrapped(self._parse_expression), + ), + } + FUNCTIONS = { **parser.Parser.FUNCTIONS, "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), @@ -207,6 +223,10 @@ class MySQL(Dialect): this=self._parse_lambda(), separator=self._match(TokenType.SEPARATOR) and self._parse_field(), ), + # https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values + "VALUES": lambda self: self.expression( + exp.Anonymous, this="VALUES", expressions=[self._parse_id_var()] + ), } STATEMENT_PARSERS = { @@ -399,6 +419,8 @@ class MySQL(Dialect): NULL_ORDERING_SUPPORTED = False JOIN_HINTS = False TABLE_HINTS = True + DUPLICATE_KEY_UPDATE_WITH_SET = False + QUERY_HINT_SEP = " " TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -445,6 +467,9 @@ class MySQL(Dialect): LIMIT_FETCH = "LIMIT" + def jsonarraycontains_sql(self, expression: exp.JSONArrayContains) -> str: + return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})" + def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: """(U)BIGINT is not allowed in a CAST expression, so we use (UN)SIGNED instead.""" if expression.to.this == exp.DataType.Type.BIGINT: diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 8d35e92..2b77ef9 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -121,7 +121,6 @@ class Oracle(Dialect): "TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD") ), exp.Group: transforms.preprocess([transforms.unalias_group]), - exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.ILike: no_ilike_sql, exp.Coalesce: rename_func("NVL"), exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 766b584..6d78a07 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -183,6 +183,29 @@ def _to_timestamp(args: t.List) -> exp.Expression: return format_time_lambda(exp.StrToTime, "postgres")(args) +def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression: + """Remove table refs from columns in when statements.""" + if isinstance(expression, exp.Merge): + alias = expression.this.args.get("alias") + + normalize = lambda identifier: Postgres.normalize_identifier(identifier).name + + targets = {normalize(expression.this.this)} + + if alias: + targets.add(normalize(alias.this)) + + for when in expression.expressions: + when.transform( + lambda node: exp.column(node.name) + if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets + else node, + copy=False, + ) + + return expression + + class Postgres(Dialect): INDEX_OFFSET = 1 NULL_ORDERING = "nulls_are_large" @@ -315,6 +338,7 @@ class Postgres(Dialect): LOCKING_READS_SUPPORTED = True JOIN_HINTS = False TABLE_HINTS = False + QUERY_HINTS = False PARAMETER_TOKEN = "$" TYPE_MAPPING = { @@ -352,7 +376,7 @@ class Postgres(Dialect): exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"), exp.ArrayContains: lambda self, e: self.binary(e, "@>"), exp.ArrayContained: lambda self, e: self.binary(e, "<@"), - exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]), + exp.Merge: transforms.preprocess([_remove_target_from_merge]), exp.Pivot: no_pivot_sql, exp.RegexpLike: lambda self, e: self.binary(e, "~"), exp.RegexpILike: lambda self, e: self.binary(e, "~*"), diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 24c439b..1721588 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -232,6 +232,7 @@ class Presto(Dialect): INTERVAL_ALLOWS_PLURAL_FORM = False JOIN_HINTS = False TABLE_HINTS = False + QUERY_HINTS = False IS_BOOL_ALLOWED = False STRUCT_DELIMITER = ("(", ")") diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 87be42c..09edd55 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -86,6 +86,7 @@ class Redshift(Postgres): class Generator(Postgres.Generator): LOCKING_READS_SUPPORTED = False RENAME_TABLE_WITH_DB = False + QUERY_HINTS = False TYPE_MAPPING = { **Postgres.Generator.TYPE_MAPPING, diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 34e4dd0..19924cd 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -326,6 +326,7 @@ class Snowflake(Dialect): SINGLE_STRING_INTERVAL = True JOIN_HINTS = False TABLE_HINTS = False + QUERY_HINTS = False TRANSFORMS = { **generator.Generator.TRANSFORMS, diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 3720b8d..afe2482 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -173,6 +173,8 @@ class Spark2(Hive): return pivot_column_names(aggregations, dialect="spark") class Generator(Hive.Generator): + QUERY_HINTS = True + TYPE_MAPPING = { **Hive.Generator.TYPE_MAPPING, exp.DataType.Type.TINYINT: "BYTE", @@ -203,7 +205,6 @@ class Spark2(Hive): exp.DayOfYear: rename_func("DAYOFYEAR"), exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", exp.From: transforms.preprocess([_unalias_pivot]), - exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.LogicalAnd: rename_func("BOOL_AND"), exp.LogicalOr: rename_func("BOOL_OR"), exp.Map: _map_sql, diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 519e62a..5ded6df 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -77,6 +77,7 @@ class SQLite(Dialect): class Generator(generator.Generator): JOIN_HINTS = False TABLE_HINTS = False + QUERY_HINTS = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index 67ef76b..33ec7e1 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -8,6 +8,7 @@ class Tableau(Dialect): class Generator(generator.Generator): JOIN_HINTS = False TABLE_HINTS = False + QUERY_HINTS = False TRANSFORMS = { **generator.Generator.TRANSFORMS, diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index d9a5417..4e8ffb4 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -121,7 +121,7 @@ class Teradata(Dialect): exp.Update, **{ # type: ignore "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), - "from": self._parse_from(modifiers=True), + "from": self._parse_from(joins=True), "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality), "where": self._parse_where(), @@ -140,6 +140,7 @@ class Teradata(Dialect): class Generator(generator.Generator): JOIN_HINTS = False TABLE_HINTS = False + QUERY_HINTS = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index f671630..92bb755 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -60,10 +60,10 @@ def _format_time_lambda( assert len(args) == 2 return exp_class( - this=args[1], + this=exp.cast(args[1], "datetime"), format=exp.Literal.string( format_time( - args[0].name, + args[0].name.lower(), {**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping else TSQL.TIME_MAPPING, @@ -467,6 +467,8 @@ class TSQL(Dialect): class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True + LIMIT_IS_TOP = True + QUERY_HINTS = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -482,6 +484,7 @@ class TSQL(Dialect): exp.DateDiff: generate_date_delta_with_unit_sql, exp.CurrentDate: rename_func("GETDATE"), exp.CurrentTimestamp: rename_func("GETDATE"), + exp.Extract: rename_func("DATEPART"), exp.GroupConcat: _string_agg_sql, exp.If: rename_func("IIF"), exp.Max: max_or_greatest, diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py index 630cb65..d7952c1 100644 --- a/sqlglot/executor/context.py +++ b/sqlglot/executor/context.py @@ -41,11 +41,13 @@ class Context: def table(self) -> Table: if self._table is None: self._table = list(self.tables.values())[0] + for other in self.tables.values(): if self._table.columns != other.columns: raise Exception(f"Columns are different.") if len(self._table.rows) != len(other.rows): raise Exception(f"Rows are different.") + return self._table def add_columns(self, *columns: str) -> None: diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index 5300224..9f63100 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -100,9 +100,19 @@ def substring(this, start=None, length=None): @null_if_any def cast(this, to): if to == exp.DataType.Type.DATE: - return datetime.date.fromisoformat(this) - if to == exp.DataType.Type.DATETIME: - return datetime.datetime.fromisoformat(this) + if isinstance(this, datetime.datetime): + return this.date() + if isinstance(this, datetime.date): + return this + if isinstance(this, str): + return datetime.date.fromisoformat(this) + if to in (exp.DataType.Type.DATETIME, exp.DataType.Type.TIMESTAMP): + if isinstance(this, datetime.datetime): + return this + if isinstance(this, datetime.date): + return datetime.datetime(this.year, this.month, this.day) + if isinstance(this, str): + return datetime.datetime.fromisoformat(this) if to == exp.DataType.Type.BOOLEAN: return bool(this) if to in exp.DataType.TEXT_TYPES: @@ -111,7 +121,7 @@ def cast(this, to): return float(this) if to in exp.DataType.NUMERIC_TYPES: return int(this) - raise NotImplementedError(f"Casting to '{to}' not implemented.") + raise NotImplementedError(f"Casting {this} to '{to}' not implemented.") def ordered(this, desc, nulls_first): @@ -153,6 +163,7 @@ ENV = { "CONCAT": null_if_any(lambda *args: "".join(args)), "SAFECONCAT": null_if_any(lambda *args: "".join(str(arg) for arg in args)), "CONCATWS": null_if_any(lambda this, *args: this.join(args)), + "DATEDIFF": null_if_any(lambda this, expression, *_: (this - expression).days), "DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)), "DIV": null_if_any(lambda e, this: e / this), "DOT": null_if_any(lambda e, this: e[this]), diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index 34a380e..d2ae79d 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -417,7 +417,9 @@ class Python(Dialect): exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", exp.In: lambda self, e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}", exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')", - exp.Is: lambda self, e: self.binary(e, "is"), + exp.Is: lambda self, e: self.binary(e, "==") + if isinstance(e.this, exp.Literal) + else self.binary(e, "is"), exp.Lambda: _lambda_sql, exp.Not: lambda self, e: f"not {self.sql(e.this)}", exp.Null: lambda *_: "None", diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index cdb93db..fdf02c8 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -102,13 +102,10 @@ class Expression(metaclass=_Expression): @property def hashable_args(self) -> t.Any: - args = (self.args.get(k) for k in self.arg_types) - - return tuple( - (tuple(_norm_arg(a) for a in arg) if arg else None) - if type(arg) is list - else (_norm_arg(arg) if arg is not None and arg is not False else None) - for arg in args + return frozenset( + (k, tuple(_norm_arg(a) for a in v) if type(v) is list else _norm_arg(v)) + for k, v in self.args.items() + if not (v is None or v is False or (type(v) is list and not v)) ) def __hash__(self) -> int: @@ -1304,6 +1301,7 @@ class Delete(Expression): "where": False, "returning": False, "limit": False, + "tables": False, # Multiple-Table Syntax (MySQL) } def delete( @@ -1490,9 +1488,7 @@ class Identifier(Expression): @property def hashable_args(self) -> t.Any: - if self.quoted and any(char.isupper() for char in self.this): - return (self.this, self.quoted) - return self.this.lower() + return (self.this, self.quoted) @property def output_name(self) -> str: @@ -1525,6 +1521,7 @@ class Insert(Expression): "partition": False, "alternative": False, "where": False, + "ignore": False, } def with_( @@ -1620,6 +1617,7 @@ class Group(Expression): "cube": False, "rollup": False, "totals": False, + "all": False, } @@ -4135,9 +4133,9 @@ class DateToDi(Func): pass +# https://cloud.google.com/bigquery/docs/reference/standard-sql/date_functions#date class Date(Func): - arg_types = {"expressions": True} - is_var_len_args = True + arg_types = {"this": True, "zone": False} class Day(Func): @@ -4245,6 +4243,11 @@ class JSONFormat(Func): _sql_names = ["JSON_FORMAT"] +# https://dev.mysql.com/doc/refman/8.0/en/json-search-functions.html#operator_member-of +class JSONArrayContains(Binary, Predicate, Func): + _sql_names = ["JSON_ARRAY_CONTAINS"] + + class Least(Func): arg_types = {"expressions": False} is_var_len_args = True @@ -4475,7 +4478,7 @@ class StrToDate(Func): class StrToTime(Func): - arg_types = {"this": True, "format": True} + arg_types = {"this": True, "format": True, "zone": False} # Spark allows unix_timestamp() diff --git a/sqlglot/generator.py b/sqlglot/generator.py index a41af12..1ce2aaa 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -140,9 +140,21 @@ class Generator: # Whether or not table hints should be generated TABLE_HINTS = True + # Whether or not query hints should be generated + QUERY_HINTS = True + + # What kind of separator to use for query hints + QUERY_HINT_SEP = ", " + # Whether or not comparing against booleans (e.g. x IS TRUE) is supported IS_BOOL_ALLOWED = True + # Whether or not to include the "SET" keyword in the "INSERT ... ON DUPLICATE KEY UPDATE" statement + DUPLICATE_KEY_UPDATE_WITH_SET = True + + # Whether or not to generate the limit as TOP <value> instead of LIMIT <value> + LIMIT_IS_TOP = False + # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE") @@ -268,6 +280,7 @@ class Generator: STRICT_STRING_CONCAT = False NORMALIZE_FUNCTIONS: bool | str = "upper" NULL_ORDERING = "nulls_are_small" + ESCAPE_LINE_BREAK = False can_identify: t.Callable[[str, str | bool], bool] @@ -286,8 +299,6 @@ class Generator: HEX_END: t.Optional[str] = None BYTE_START: t.Optional[str] = None BYTE_END: t.Optional[str] = None - RAW_START: t.Optional[str] = None - RAW_END: t.Optional[str] = None __slots__ = ( "pretty", @@ -486,7 +497,10 @@ class Generator: return expression if key: - return self.sql(expression.args.get(key)) + value = expression.args.get(key) + if value: + return self.sql(value) + return "" if self._cache is not None: expression_id = hash(expression) @@ -779,10 +793,7 @@ class Generator: return this def rawstring_sql(self, expression: exp.RawString) -> str: - string = expression.this - if self.RAW_START: - return f"{self.RAW_START}{self.escape_str(expression.this)}{self.RAW_END}" - string = self.escape_str(string.replace("\\", "\\\\")) + string = self.escape_str(expression.this.replace("\\", "\\\\")) return f"{self.QUOTE_START}{string}{self.QUOTE_END}" def datatypesize_sql(self, expression: exp.DataTypeSize) -> str: @@ -818,15 +829,14 @@ class Generator: def delete_sql(self, expression: exp.Delete) -> str: this = self.sql(expression, "this") this = f" FROM {this}" if this else "" - using_sql = ( - f" USING {self.expressions(expression, key='using', sep=', USING ')}" - if expression.args.get("using") - else "" - ) - where_sql = self.sql(expression, "where") + using = self.sql(expression, "using") + using = f" USING {using}" if using else "" + where = self.sql(expression, "where") returning = self.sql(expression, "returning") limit = self.sql(expression, "limit") - sql = f"DELETE{this}{using_sql}{where_sql}{returning}{limit}" + tables = self.expressions(expression, key="tables") + tables = f" {tables}" if tables else "" + sql = f"DELETE{tables}{this}{using}{where}{returning}{limit}" return self.prepend_ctes(expression, sql) def drop_sql(self, expression: exp.Drop) -> str: @@ -867,9 +877,11 @@ class Generator: return f"{this} FILTER({where})" def hint_sql(self, expression: exp.Hint) -> str: - if self.sql(expression, "this"): + if not self.QUERY_HINTS: self.unsupported("Hints are not supported") - return "" + return "" + + return f" /*+ {self.expressions(expression, sep=self.QUERY_HINT_SEP).strip()} */" def index_sql(self, expression: exp.Index) -> str: unique = "UNIQUE " if expression.args.get("unique") else "" @@ -1109,6 +1121,8 @@ class Generator: alternative = expression.args.get("alternative") alternative = f" OR {alternative}" if alternative else "" + ignore = " IGNORE" if expression.args.get("ignore") else "" + this = f"{this} {self.sql(expression, 'this')}" exists = " IF EXISTS" if expression.args.get("exists") else "" @@ -1120,7 +1134,7 @@ class Generator: expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}" conflict = self.sql(expression, "conflict") returning = self.sql(expression, "returning") - sql = f"INSERT{alternative}{this}{exists}{partition_sql}{where}{expression_sql}{conflict}{returning}" + sql = f"INSERT{alternative}{ignore}{this}{exists}{partition_sql}{where}{expression_sql}{conflict}{returning}" return self.prepend_ctes(expression, sql) def intersect_sql(self, expression: exp.Intersect) -> str: @@ -1147,8 +1161,9 @@ class Generator: do = "" if expression.args.get("duplicate") else " DO " nothing = "NOTHING" if expression.args.get("nothing") else "" expressions = self.expressions(expression, flat=True) + set_keyword = "SET " if self.DUPLICATE_KEY_UPDATE_WITH_SET else "" if expressions: - expressions = f"UPDATE SET {expressions}" + expressions = f"UPDATE {set_keyword}{expressions}" return f"{self.seg(conflict)} {constraint}{key}{do}{nothing}{expressions}" def returning_sql(self, expression: exp.Returning) -> str: @@ -1195,7 +1210,7 @@ class Generator: hints = f" {hints}" if hints and self.TABLE_HINTS else "" pivots = self.expressions(expression, key="pivots", sep=" ", flat=True) pivots = f" {pivots}" if pivots else "" - joins = self.expressions(expression, key="joins", sep="") + joins = self.expressions(expression, key="joins", sep="", skip_first=True) laterals = self.expressions(expression, key="laterals", sep="") system_time = expression.args.get("system_time") system_time = f" {self.sql(expression, 'system_time')}" if system_time else "" @@ -1287,6 +1302,10 @@ class Generator: def group_sql(self, expression: exp.Group) -> str: group_by = self.op_expressions("GROUP BY", expression) + + if expression.args.get("all"): + return f"{group_by} ALL" + grouping_sets = self.expressions(expression, key="grouping_sets", indent=False) grouping_sets = ( f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else "" @@ -1379,7 +1398,7 @@ class Generator: alias = f" AS {alias}" if alias else "" return f"LATERAL {this}{alias}" - def limit_sql(self, expression: exp.Limit) -> str: + def limit_sql(self, expression: exp.Limit, top: bool = False) -> str: this = self.sql(expression, "this") args = ", ".join( sql @@ -1389,7 +1408,7 @@ class Generator: ) if sql ) - return f"{this}{self.seg('LIMIT')} {args}" + return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args}" def offset_sql(self, expression: exp.Offset) -> str: this = self.sql(expression, "this") @@ -1441,7 +1460,9 @@ class Generator: def escape_str(self, text: str) -> str: text = text.replace(self.QUOTE_END, self._escaped_quote_end) - if self.pretty: + if self.ESCAPE_LINE_BREAK: + text = text.replace("\n", "\\n") + elif self.pretty: text = text.replace("\n", self.SENTINEL_LINE_BREAK) return text @@ -1544,6 +1565,9 @@ class Generator: def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: limit: t.Optional[exp.Fetch | exp.Limit] = expression.args.get("limit") + # If the limit is generated as TOP, we need to ensure it's not generated twice + with_offset_limit_modifiers = not isinstance(limit, exp.Limit) or not self.LIMIT_IS_TOP + if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch): limit = exp.Limit(expression=limit.args.get("count")) elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit): @@ -1551,6 +1575,12 @@ class Generator: fetch = isinstance(limit, exp.Fetch) + offset_limit_modifiers = ( + self.offset_limit_modifiers(expression, fetch, limit) + if with_offset_limit_modifiers + else [] + ) + return csv( *sqls, *[self.sql(join) for join in expression.args.get("joins") or []], @@ -1561,7 +1591,7 @@ class Generator: self.sql(expression, "having"), *self.after_having_modifiers(expression), self.sql(expression, "order"), - *self.offset_limit_modifiers(expression, fetch, limit), + *offset_limit_modifiers, *self.after_limit_modifiers(expression), sep="", ) @@ -1580,6 +1610,9 @@ class Generator: self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True) if expression.args.get("windows") else "", + self.sql(expression, "distribute"), + self.sql(expression, "sort"), + self.sql(expression, "cluster"), ] def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]: @@ -1592,6 +1625,13 @@ class Generator: distinct = self.sql(expression, "distinct") distinct = f" {distinct}" if distinct else "" kind = self.sql(expression, "kind").upper() + limit = expression.args.get("limit") + top = ( + self.limit_sql(limit, top=True) + if isinstance(limit, exp.Limit) and self.LIMIT_IS_TOP + else "" + ) + expressions = self.expressions(expression) if kind: @@ -1618,7 +1658,7 @@ class Generator: expressions = f"{self.sep()}{expressions}" if expressions else expressions sql = self.query_modifiers( expression, - f"SELECT{hint}{distinct}{kind}{expressions}", + f"SELECT{top}{hint}{distinct}{kind}{expressions}", self.sql(expression, "into", comment=False), self.sql(expression, "from", comment=False), ) @@ -2288,6 +2328,7 @@ class Generator: sqls: t.Optional[t.List[str]] = None, flat: bool = False, indent: bool = True, + skip_first: bool = False, sep: str = ", ", prefix: str = "", ) -> str: @@ -2321,7 +2362,7 @@ class Generator: result_sqls.append(f"{prefix}{sql}{comments}{sep if i + 1 < num_sqls else ''}") result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) - return self.indent(result_sql, skip_first=False) if indent else result_sql + return self.indent(result_sql, skip_first=skip_first) if indent else result_sql def op_expressions(self, op: str, expression: exp.Expression, flat: bool = False) -> str: flat = flat or isinstance(expression.parent, exp.Properties) diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 0fc5f4c..e7cb80b 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -105,6 +105,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.CurrentDate, exp.Date, exp.DateAdd, + exp.DateFromParts, exp.DateStrToDate, exp.DateSub, exp.DateTrunc, diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 8c3f599..435585c 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -220,25 +220,6 @@ def _expand_group_by(scope: Scope): group.set("expressions", _expand_positional_references(scope, group.expressions)) expression.set("group", group) - # group by expressions cannot be simplified, for example - # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 - # the projection must exactly match the group by key - groups = set(group.expressions) - group.meta["final"] = True - - for e in expression.selects: - for node, *_ in e.walk(): - if node in groups: - e.meta["final"] = True - break - - having = expression.args.get("having") - if having: - for node, *_ in having.walk(): - if node in groups: - having.meta["final"] = True - break - def _expand_order_by(scope: Scope, resolver: Resolver): order = scope.expression.args.get("order") diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 1a2d82c..e247f58 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -8,6 +8,9 @@ from sqlglot import exp from sqlglot.generator import cached_generator from sqlglot.helper import first, while_changing +# Final means that an expression should not be simplified +FINAL = "final" + def simplify(expression): """ @@ -27,8 +30,29 @@ def simplify(expression): generate = cached_generator() + # group by expressions cannot be simplified, for example + # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 + # the projection must exactly match the group by key + for group in expression.find_all(exp.Group): + select = group.parent + groups = set(group.expressions) + group.meta[FINAL] = True + + for e in select.selects: + for node, *_ in e.walk(): + if node in groups: + e.meta[FINAL] = True + break + + having = select.args.get("having") + if having: + for node, *_ in having.walk(): + if node in groups: + having.meta[FINAL] = True + break + def _simplify(expression, root=True): - if expression.meta.get("final"): + if expression.meta.get(FINAL): return expression node = expression node = rewrite_between(node) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 79e7cac..f7fd6ba 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -737,19 +737,29 @@ class Parser(metaclass=_Parser): } QUERY_MODIFIER_PARSERS = { - "joins": lambda self: list(iter(self._parse_join, None)), - "laterals": lambda self: list(iter(self._parse_lateral, None)), - "match": lambda self: self._parse_match_recognize(), - "where": lambda self: self._parse_where(), - "group": lambda self: self._parse_group(), - "having": lambda self: self._parse_having(), - "qualify": lambda self: self._parse_qualify(), - "windows": lambda self: self._parse_window_clause(), - "order": lambda self: self._parse_order(), - "limit": lambda self: self._parse_limit(), - "offset": lambda self: self._parse_offset(), - "locks": lambda self: self._parse_locks(), - "sample": lambda self: self._parse_table_sample(as_modifier=True), + TokenType.MATCH_RECOGNIZE: lambda self: ("match", self._parse_match_recognize()), + TokenType.WHERE: lambda self: ("where", self._parse_where()), + TokenType.GROUP_BY: lambda self: ("group", self._parse_group()), + TokenType.HAVING: lambda self: ("having", self._parse_having()), + TokenType.QUALIFY: lambda self: ("qualify", self._parse_qualify()), + TokenType.WINDOW: lambda self: ("windows", self._parse_window_clause()), + TokenType.ORDER_BY: lambda self: ("order", self._parse_order()), + TokenType.LIMIT: lambda self: ("limit", self._parse_limit()), + TokenType.FETCH: lambda self: ("limit", self._parse_limit()), + TokenType.OFFSET: lambda self: ("offset", self._parse_offset()), + TokenType.FOR: lambda self: ("locks", self._parse_locks()), + TokenType.LOCK: lambda self: ("locks", self._parse_locks()), + TokenType.TABLE_SAMPLE: lambda self: ("sample", self._parse_table_sample(as_modifier=True)), + TokenType.USING: lambda self: ("sample", self._parse_table_sample(as_modifier=True)), + TokenType.CLUSTER_BY: lambda self: ( + "cluster", + self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY), + ), + TokenType.DISTRIBUTE_BY: lambda self: ( + "distribute", + self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY), + ), + TokenType.SORT_BY: lambda self: ("sort", self._parse_sort(exp.Sort, TokenType.SORT_BY)), } SET_PARSERS = { @@ -1679,6 +1689,7 @@ class Parser(metaclass=_Parser): def _parse_insert(self) -> exp.Insert: overwrite = self._match(TokenType.OVERWRITE) + ignore = self._match(TokenType.IGNORE) local = self._match_text_seq("LOCAL") alternative = None @@ -1709,6 +1720,7 @@ class Parser(metaclass=_Parser): returning=self._parse_returning(), overwrite=overwrite, alternative=alternative, + ignore=ignore, ) def _parse_on_conflict(self) -> t.Optional[exp.OnConflict]: @@ -1734,7 +1746,8 @@ class Parser(metaclass=_Parser): nothing = True else: self._match(TokenType.UPDATE) - expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality) + self._match(TokenType.SET) + expressions = self._parse_csv(self._parse_equality) return self.expression( exp.OnConflict, @@ -1805,12 +1818,17 @@ class Parser(metaclass=_Parser): return self._parse_as_command(self._prev) def _parse_delete(self) -> exp.Delete: - self._match(TokenType.FROM) + # This handles MySQL's "Multiple-Table Syntax" + # https://dev.mysql.com/doc/refman/8.0/en/delete.html + tables = None + if not self._match(TokenType.FROM, advance=False): + tables = self._parse_csv(self._parse_table) or None return self.expression( exp.Delete, - this=self._parse_table(), - using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()), + tables=tables, + this=self._match(TokenType.FROM) and self._parse_table(joins=True), + using=self._match(TokenType.USING) and self._parse_table(joins=True), where=self._parse_where(), returning=self._parse_returning(), limit=self._parse_limit(), @@ -1822,7 +1840,7 @@ class Parser(metaclass=_Parser): **{ # type: ignore "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality), - "from": self._parse_from(modifiers=True), + "from": self._parse_from(joins=True), "where": self._parse_where(), "returning": self._parse_returning(), "limit": self._parse_limit(), @@ -1875,7 +1893,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Tuple, expressions=expressions) # In presto we can have VALUES 1, 2 which results in 1 column & 2 rows. - # Source: https://prestodb.io/docs/current/sql/values.html + # https://prestodb.io/docs/current/sql/values.html return self.expression(exp.Tuple, expressions=[self._parse_conjunction()]) def _parse_select( @@ -1917,7 +1935,7 @@ class Parser(metaclass=_Parser): self.raise_error("Cannot specify both ALL and DISTINCT after SELECT") limit = self._parse_limit(top=True) - expressions = self._parse_csv(self._parse_expression) + expressions = self._parse_expressions() this = self.expression( exp.Select, @@ -2034,20 +2052,31 @@ class Parser(metaclass=_Parser): self, this: t.Optional[exp.Expression] ) -> t.Optional[exp.Expression]: if isinstance(this, self.MODIFIABLES): - for key, parser in self.QUERY_MODIFIER_PARSERS.items(): - expression = parser(self) - - if expression: - if key == "limit": - offset = expression.args.pop("offset", None) - if offset: - this.set("offset", exp.Offset(expression=offset)) - this.set(key, expression) + for join in iter(self._parse_join, None): + this.append("joins", join) + for lateral in iter(self._parse_lateral, None): + this.append("laterals", lateral) + + while True: + if self._match_set(self.QUERY_MODIFIER_PARSERS, advance=False): + parser = self.QUERY_MODIFIER_PARSERS[self._curr.token_type] + key, expression = parser(self) + + if expression: + this.set(key, expression) + if key == "limit": + offset = expression.args.pop("offset", None) + if offset: + this.set("offset", exp.Offset(expression=offset)) + continue + break return this def _parse_hint(self) -> t.Optional[exp.Hint]: if self._match(TokenType.HINT): - hints = self._parse_csv(self._parse_function) + hints = [] + for hint in iter(lambda: self._parse_csv(self._parse_function), []): + hints.extend(hint) if not self._match_pair(TokenType.STAR, TokenType.SLASH): self.raise_error("Expected */ after HINT") @@ -2069,18 +2098,13 @@ class Parser(metaclass=_Parser): ) def _parse_from( - self, modifiers: bool = False, skip_from_token: bool = False + self, joins: bool = False, skip_from_token: bool = False ) -> t.Optional[exp.From]: if not skip_from_token and not self._match(TokenType.FROM): return None - comments = self._prev_comments - this = self._parse_table() - return self.expression( - exp.From, - comments=comments, - this=self._parse_query_modifiers(this) if modifiers else this, + exp.From, comments=self._prev_comments, this=self._parse_table(joins=joins) ) def _parse_match_recognize(self) -> t.Optional[exp.MatchRecognize]: @@ -2091,9 +2115,7 @@ class Parser(metaclass=_Parser): partition = self._parse_partition_by() order = self._parse_order() - measures = ( - self._parse_csv(self._parse_expression) if self._match_text_seq("MEASURES") else None - ) + measures = self._parse_expressions() if self._match_text_seq("MEASURES") else None if self._match_text_seq("ONE", "ROW", "PER", "MATCH"): rows = exp.var("ONE ROW PER MATCH") @@ -2259,6 +2281,18 @@ class Parser(metaclass=_Parser): kwargs["on"] = self._parse_conjunction() elif self._match(TokenType.USING): kwargs["using"] = self._parse_wrapped_id_vars() + elif not (kind and kind.token_type == TokenType.CROSS): + index = self._index + joins = self._parse_joins() + + if joins and self._match(TokenType.ON): + kwargs["on"] = self._parse_conjunction() + elif joins and self._match(TokenType.USING): + kwargs["using"] = self._parse_wrapped_id_vars() + else: + joins = None + self._retreat(index) + kwargs["this"].set("joins", joins) return self.expression(exp.Join, **kwargs) @@ -2363,7 +2397,10 @@ class Parser(metaclass=_Parser): ) def _parse_table( - self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None + self, + schema: bool = False, + joins: bool = False, + alias_tokens: t.Optional[t.Collection[TokenType]] = None, ) -> t.Optional[exp.Expression]: lateral = self._parse_lateral() if lateral: @@ -2407,6 +2444,10 @@ class Parser(metaclass=_Parser): table_sample.set("this", this) this = table_sample + if joins: + for join in iter(self._parse_join, None): + this.append("joins", join) + return this def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]: @@ -2507,8 +2548,11 @@ class Parser(metaclass=_Parser): kind=kind, ) - def _parse_pivots(self) -> t.List[t.Optional[exp.Expression]]: - return list(iter(self._parse_pivot, None)) + def _parse_pivots(self) -> t.Optional[t.List[exp.Pivot]]: + return list(iter(self._parse_pivot, None)) or None + + def _parse_joins(self) -> t.Optional[t.List[exp.Join]]: + return list(iter(self._parse_join, None)) or None # https://duckdb.org/docs/sql/statements/pivot def _parse_simplified_pivot(self) -> exp.Pivot: @@ -2603,6 +2647,9 @@ class Parser(metaclass=_Parser): elements = defaultdict(list) + if self._match(TokenType.ALL): + return self.expression(exp.Group, all=True) + while True: expressions = self._parse_csv(self._parse_conjunction) if expressions: @@ -3171,7 +3218,7 @@ class Parser(metaclass=_Parser): if query: expressions = [query] else: - expressions = self._parse_csv(self._parse_expression) + expressions = self._parse_expressions() this = self._parse_query_modifiers(seq_get(expressions, 0)) @@ -3536,11 +3583,7 @@ class Parser(metaclass=_Parser): return None expressions = None - this = self._parse_id_var() - - if self._match(TokenType.L_PAREN, advance=False): - expressions = self._parse_wrapped_id_vars() - + this = self._parse_table(schema=True) options = self._parse_key_constraint_options() return self.expression(exp.Reference, this=this, expressions=expressions, options=options) @@ -3706,21 +3749,27 @@ class Parser(metaclass=_Parser): if self._match(TokenType.CHARACTER_SET): to = self.expression(exp.CharacterSet, this=self._parse_var_or_string()) elif self._match(TokenType.FORMAT): - fmt = self._parse_at_time_zone(self._parse_string()) + fmt_string = self._parse_string() + fmt = self._parse_at_time_zone(fmt_string) if to.this in exp.DataType.TEMPORAL_TYPES: - return self.expression( + this = self.expression( exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime, this=this, format=exp.Literal.string( format_time( - fmt.this if fmt else "", + fmt_string.this if fmt_string else "", self.FORMAT_MAPPING or self.TIME_MAPPING, self.FORMAT_TRIE or self.TIME_TRIE, ) ), ) + if isinstance(fmt, exp.AtTimeZone) and isinstance(this, exp.StrToTime): + this.set("zone", fmt.args["zone"]) + + return this + return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt) def _parse_concat(self) -> t.Optional[exp.Expression]: @@ -4223,7 +4272,7 @@ class Parser(metaclass=_Parser): return None if self._match(TokenType.L_PAREN, advance=False): return self._parse_wrapped_csv(self._parse_expression) - return self._parse_csv(self._parse_expression) + return self._parse_expressions() def _parse_csv( self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA @@ -4273,6 +4322,9 @@ class Parser(metaclass=_Parser): self._match_r_paren() return parse_result + def _parse_expressions(self) -> t.List[t.Optional[exp.Expression]]: + return self._parse_csv(self._parse_expression) + def _parse_select_or_expression(self, alias: bool = False) -> t.Optional[exp.Expression]: return self._parse_select() or self._parse_set_operations( self._parse_expression() if alias else self._parse_conjunction() diff --git a/sqlglot/planner.py b/sqlglot/planner.py index f246702..07ee739 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -23,9 +23,11 @@ class Plan: while nodes: node = nodes.pop() dag[node] = set() + for dep in node.dependencies: dag[node].add(dep) nodes.add(dep) + self._dag = dag return self._dag @@ -128,15 +130,22 @@ class Step: agg_funcs = tuple(expression.find_all(exp.AggFunc)) if agg_funcs: aggregations.add(expression) + for agg in agg_funcs: for operand in agg.unnest_operands(): if isinstance(operand, exp.Column): continue if operand not in operands: operands[operand] = next_operand_name() + operand.replace(exp.column(operands[operand], quoted=True)) + return bool(agg_funcs) + def set_ops_and_aggs(step): + step.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items()) + step.aggregations = list(aggregations) + for e in expression.expressions: if e.find(exp.AggFunc): projections.append(exp.column(e.alias_or_name, step.name, quoted=True)) @@ -164,10 +173,7 @@ class Step: else: aggregate.condition = having.this - aggregate.operands = tuple( - alias(operand, alias_) for operand, alias_ in operands.items() - ) - aggregate.aggregations = list(aggregations) + set_ops_and_aggs(aggregate) # give aggregates names and replace projections with references to them aggregate.group = { @@ -178,13 +184,14 @@ class Step: for k, v in aggregate.group.items(): intermediate[v] = k if isinstance(v, exp.Column): - intermediate[v.alias_or_name] = k + intermediate[v.name] = k for projection in projections: for node, *_ in projection.walk(): name = intermediate.get(node) if name: node.replace(exp.column(name, step.name)) + if aggregate.condition: for node, *_ in aggregate.condition.walk(): name = intermediate.get(node) or intermediate.get(node.name) @@ -197,6 +204,13 @@ class Step: order = expression.args.get("order") if order: + if isinstance(step, Aggregate): + for i, ordered in enumerate(order.expressions): + if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)): + ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True)) + + set_ops_and_aggs(aggregate) + sort = Sort() sort.name = step.name sort.key = order.expressions @@ -340,7 +354,10 @@ class Join(Step): def _to_s(self, indent: str) -> t.List[str]: lines = [] for name, join in self.joins.items(): - lines.append(f"{indent}{name}: {join['side']}") + lines.append(f"{indent}{name}: {join['side'] or 'INNER'}") + join_key = ", ".join(str(key) for key in t.cast(list, join.get("join_key") or [])) + if join_key: + lines.append(f"{indent}Key: {join_key}") if join.get("condition"): lines.append(f"{indent}On: {join['condition'].sql()}") # type: ignore return lines diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 79f7a65..999bde2 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -239,6 +239,7 @@ class TokenType(AutoName): LOCK = auto() MAP = auto() MATCH_RECOGNIZE = auto() + MEMBER_OF = auto() MERGE = auto() MOD = auto() NATURAL = auto() @@ -944,8 +945,6 @@ class Tokenizer(metaclass=_Tokenizer): char = "" chars = " " - word = None if not single_token and chars[-1] not in self.WHITE_SPACE else word - if not word: if self._char in self.SINGLE_TOKENS: self._add(self.SINGLE_TOKENS[self._char], text=self._char) diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 33a1bc0..1e6cfc8 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -124,7 +124,9 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr other expressions. This transforms removes the precision from parameterized types in expressions. """ for node in expression.find_all(exp.DataType): - node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)]) + node.set( + "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeSize)] + ) return expression @@ -215,25 +217,6 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression: return expression -def remove_target_from_merge(expression: exp.Expression) -> exp.Expression: - """Remove table refs from columns in when statements.""" - if isinstance(expression, exp.Merge): - alias = expression.this.args.get("alias") - targets = {expression.this.this} - if alias: - targets.add(alias.this) - - for when in expression.expressions: - when.transform( - lambda node: exp.column(node.name) - if isinstance(node, exp.Column) and node.args.get("table") in targets - else node, - copy=False, - ) - - return expression - - def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: if ( isinstance(expression, exp.WithinGroup) |