From 9b39dac84e82bf473216939e50b8836170f01d23 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 29 Jun 2023 15:02:29 +0200 Subject: Merging upstream version 16.7.3. Signed-off-by: Daniel Baumann --- sqlglot/dataframe/sql/functions.py | 2 +- sqlglot/dialects/bigquery.py | 205 +++++++++++++++++++++++++++++++++-- sqlglot/dialects/dialect.py | 5 + sqlglot/dialects/mysql.py | 15 +++ sqlglot/dialects/postgres.py | 10 ++ sqlglot/dialects/presto.py | 13 ++- sqlglot/dialects/redshift.py | 8 +- sqlglot/dialects/snowflake.py | 31 ++++-- sqlglot/dialects/spark.py | 1 + sqlglot/dialects/spark2.py | 4 +- sqlglot/dialects/sqlite.py | 4 +- sqlglot/dialects/tsql.py | 1 + sqlglot/executor/context.py | 6 +- sqlglot/executor/python.py | 6 +- sqlglot/expressions.py | 28 ++++- sqlglot/generator.py | 22 +++- sqlglot/optimizer/qualify.py | 2 +- sqlglot/optimizer/qualify_columns.py | 114 +++++++++++++------ sqlglot/optimizer/simplify.py | 2 + sqlglot/parser.py | 110 ++++++++++++------- sqlglot/planner.py | 40 +++++-- sqlglot/transforms.py | 3 +- 22 files changed, 500 insertions(+), 132 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 71385aa..bdc1fb4 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -1119,7 +1119,7 @@ def map_entries(col: ColumnOrName) -> Column: def map_from_entries(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "MAP_FROM_ENTRIES") + return Column.invoke_expression_over_column(col, expression.MapFromEntries) def array_repeat(col: ColumnOrName, count: t.Union[ColumnOrName, int]) -> Column: diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 52d4a88..8786063 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import re import typing as t @@ -21,6 +22,8 @@ from sqlglot.dialects.dialect import ( from sqlglot.helper import seq_get, split_num_words from sqlglot.tokens import TokenType +logger = logging.getLogger("sqlglot") + def _date_add_sql( data_type: str, kind: str @@ -104,12 +107,70 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression: return expression +# https://issuetracker.google.com/issues/162294746 +# workaround for bigquery bug when grouping by an expression and then ordering +# WITH x AS (SELECT 1 y) +# SELECT y + 1 z +# FROM x +# GROUP BY x + 1 +# ORDER by z +def _alias_ordered_group(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Select): + group = expression.args.get("group") + order = expression.args.get("order") + + if group and order: + aliases = { + select.this: select.args["alias"] + for select in expression.selects + if isinstance(select, exp.Alias) + } + + for e in group.expressions: + alias = aliases.get(e) + + if alias: + e.replace(exp.column(alias)) + + return expression + + +def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression: + """BigQuery doesn't allow column names when defining a CTE, so we try to push them down.""" + if isinstance(expression, exp.CTE) and expression.alias_column_names: + cte_query = expression.this + + if cte_query.is_star: + logger.warning( + "Can't push down CTE column names for star queries. Run the query through" + " the optimizer or use 'qualify' to expand the star projections first." 
+ ) + return expression + + column_names = expression.alias_column_names + expression.args["alias"].set("columns", None) + + for name, select in zip(column_names, cte_query.selects): + to_replace = select + + if isinstance(select, exp.Alias): + select = select.this + + # Inner aliases are shadowed by the CTE column names + to_replace.replace(exp.alias_(select, name)) + + return expression + + class BigQuery(Dialect): UNNEST_COLUMN_ONLY = True # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + # bigquery udfs are case sensitive + NORMALIZE_FUNCTIONS = False + TIME_MAPPING = { "%D": "%m/%d/%y", } @@ -135,12 +196,16 @@ class BigQuery(Dialect): # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least). # The following check is essentially a heuristic to detect tables based on whether or # not they're qualified. - if ( - isinstance(expression, exp.Identifier) - and not (isinstance(expression.parent, exp.Table) and expression.parent.db) - and not expression.meta.get("is_table") - ): - expression.set("this", expression.this.lower()) + if isinstance(expression, exp.Identifier): + parent = expression.parent + + while isinstance(parent, exp.Dot): + parent = parent.parent + + if not (isinstance(parent, exp.Table) and parent.db) and not expression.meta.get( + "is_table" + ): + expression.set("this", expression.this.lower()) return expression @@ -298,10 +363,8 @@ class BigQuery(Dialect): **generator.Generator.TRANSFORMS, exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.ArraySize: rename_func("ARRAY_LENGTH"), - exp.AtTimeZone: lambda self, e: self.func( - "TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone")) - ), exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]), + exp.CTE: transforms.preprocess([_pushdown_cte_column_names]), exp.DateAdd: _date_add_sql("DATE", "ADD"), exp.DateSub: _date_add_sql("DATE", "SUB"), exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"), @@ -325,7 +388,12 @@ class BigQuery(Dialect): ), exp.RegexpLike: rename_func("REGEXP_CONTAINS"), exp.Select: transforms.preprocess( - [_unqualify_unnest, transforms.eliminate_distinct_on] + [ + transforms.explode_to_unnest, + _unqualify_unnest, + transforms.eliminate_distinct_on, + _alias_ordered_group, + ] ), exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})", exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})", @@ -334,7 +402,6 @@ class BigQuery(Dialect): exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"), exp.TimeStrToTime: timestrtotime_sql, - exp.TryCast: lambda self, e: f"SAFE_CAST({self.sql(e, 'this')} AS {self.sql(e, 'to')})", exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"), exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", @@ -378,7 +445,121 @@ class BigQuery(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } - RESERVED_KEYWORDS = {*generator.Generator.RESERVED_KEYWORDS, "hash"} + # from: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords + RESERVED_KEYWORDS = { + *generator.Generator.RESERVED_KEYWORDS, + "all", + "and", + "any", + "array", + "as", + "asc", + "assert_rows_modified", + "at", + "between", + "by", + "case", + "cast", + "collate", + "contains", + "create", + "cross", + "cube", + "current", + 
"default", + "define", + "desc", + "distinct", + "else", + "end", + "enum", + "escape", + "except", + "exclude", + "exists", + "extract", + "false", + "fetch", + "following", + "for", + "from", + "full", + "group", + "grouping", + "groups", + "hash", + "having", + "if", + "ignore", + "in", + "inner", + "intersect", + "interval", + "into", + "is", + "join", + "lateral", + "left", + "like", + "limit", + "lookup", + "merge", + "natural", + "new", + "no", + "not", + "null", + "nulls", + "of", + "on", + "or", + "order", + "outer", + "over", + "partition", + "preceding", + "proto", + "qualify", + "range", + "recursive", + "respect", + "right", + "rollup", + "rows", + "select", + "set", + "some", + "struct", + "tablesample", + "then", + "to", + "treat", + "true", + "unbounded", + "union", + "unnest", + "using", + "when", + "where", + "window", + "with", + "within", + } + + def attimezone_sql(self, expression: exp.AtTimeZone) -> str: + if not isinstance(expression.parent, exp.Cast): + return self.func( + "TIMESTAMP", self.func("DATETIME", expression.this, expression.args.get("zone")) + ) + return super().attimezone_sql(expression) + + def trycast_sql(self, expression: exp.TryCast) -> str: + return self.cast_sql(expression, safe_prefix="SAFE_") + + def cte_sql(self, expression: exp.CTE) -> str: + if expression.alias_column_names: + self.unsupported("Column names in CTE definition are not supported.") + return super().cte_sql(expression) def array_sql(self, expression: exp.Array) -> str: first_arg = seq_get(expression.expressions, 0) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 0e25b9b..d258826 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -388,6 +388,11 @@ def no_comment_column_constraint_sql( return "" +def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: + self.unsupported("MAP_FROM_ENTRIES unsupported") + return "" + + def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: this = self.sql(expression, "this") substr = self.sql(expression, "substr") diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 1dd2096..5f743ee 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -132,6 +132,10 @@ class MySQL(Dialect): "SEPARATOR": TokenType.SEPARATOR, "ENUM": TokenType.ENUM, "START": TokenType.BEGIN, + "SIGNED": TokenType.BIGINT, + "SIGNED INTEGER": TokenType.BIGINT, + "UNSIGNED": TokenType.UBIGINT, + "UNSIGNED INTEGER": TokenType.UBIGINT, "_ARMSCII8": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER, "_BIG5": TokenType.INTRODUCER, @@ -441,6 +445,17 @@ class MySQL(Dialect): LIMIT_FETCH = "LIMIT" + def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: + """(U)BIGINT is not allowed in a CAST expression, so we use (UN)SIGNED instead.""" + if expression.to.this == exp.DataType.Type.BIGINT: + to = "SIGNED" + elif expression.to.this == exp.DataType.Type.UBIGINT: + to = "UNSIGNED" + else: + return super().cast_sql(expression) + + return f"CAST({self.sql(expression, 'this')} AS {to})" + def show_sql(self, expression: exp.Show) -> str: this = f" {expression.name}" full = " FULL" if expression.args.get("full") else "" diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 8c2a4ab..766b584 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import ( format_time_lambda, max_or_greatest, min_or_least, + no_map_from_entries_sql, 
no_paren_current_date_sql, no_pivot_sql, no_tablesample_sql, @@ -346,6 +347,7 @@ class Postgres(Dialect): exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), exp.Max: max_or_greatest, + exp.MapFromEntries: no_map_from_entries_sql, exp.Min: min_or_least, exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"), exp.ArrayContains: lambda self, e: self.binary(e, "@>"), @@ -378,3 +380,11 @@ class Postgres(Dialect): exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + + def bracket_sql(self, expression: exp.Bracket) -> str: + """Forms like ARRAY[1, 2, 3][3] aren't allowed; we need to wrap the ARRAY.""" + if isinstance(expression.this, exp.Array): + expression = expression.copy() + expression.set("this", exp.paren(expression.this, copy=False)) + + return super().bracket_sql(expression) diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 265780e..24c439b 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -20,7 +20,7 @@ from sqlglot.dialects.dialect import ( ) from sqlglot.dialects.mysql import MySQL from sqlglot.errors import UnsupportedError -from sqlglot.helper import seq_get +from sqlglot.helper import apply_index_offset, seq_get from sqlglot.tokens import TokenType @@ -154,6 +154,13 @@ def _from_unixtime(args: t.List) -> exp.Expression: return exp.UnixToTime.from_arg_list(args) +def _parse_element_at(args: t.List) -> exp.SafeBracket: + this = seq_get(args, 0) + index = seq_get(args, 1) + assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression) + return exp.SafeBracket(this=this, expressions=apply_index_offset(this, [index], -1)) + + def _unnest_sequence(expression: exp.Expression) -> exp.Expression: if isinstance(expression, exp.Table): if isinstance(expression.this, exp.GenerateSeries): @@ -201,6 +208,7 @@ class Presto(Dialect): "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), "DATE_TRUNC": date_trunc_to_time, + "ELEMENT_AT": _parse_element_at, "FROM_HEX": exp.Unhex.from_arg_list, "FROM_UNIXTIME": _from_unixtime, "FROM_UTF8": lambda args: exp.Decode( @@ -285,6 +293,9 @@ class Presto(Dialect): exp.Pivot: no_pivot_sql, exp.Quantile: _quantile_sql, exp.Right: right_to_substring_sql, + exp.SafeBracket: lambda self, e: self.func( + "ELEMENT_AT", e.this, seq_get(apply_index_offset(e.this, e.expressions, 1), 0) + ), exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, exp.Select: transforms.preprocess( diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index db6cc3f..87be42c 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -41,8 +41,6 @@ class Redshift(Postgres): "STRTOL": exp.FromBase.from_arg_list, } - CONVERT_TYPE_FIRST = True - def _parse_types( self, check_func: bool = False, schema: bool = False ) -> t.Optional[exp.Expression]: @@ -58,6 +56,12 @@ class Redshift(Postgres): return this + def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]: + to = self._parse_types() + self._match(TokenType.COMMA) + this = self._parse_bitwise() + return self.expression(exp.TryCast, this=this, to=to) + class Tokenizer(Postgres.Tokenizer): BIT_STRINGS = [] HEX_STRINGS = [] diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 1f620df..a2dbfd9 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -258,14 +258,29 @@ class Snowflake(Dialect): 
ALTER_PARSERS = { **parser.Parser.ALTER_PARSERS, - "UNSET": lambda self: self._parse_alter_table_set_tag(unset=True), - "SET": lambda self: self._parse_alter_table_set_tag(), + "SET": lambda self: self._parse_set(tag=self._match_text_seq("TAG")), + "UNSET": lambda self: self.expression( + exp.Set, + tag=self._match_text_seq("TAG"), + expressions=self._parse_csv(self._parse_id_var), + unset=True, + ), } - def _parse_alter_table_set_tag(self, unset: bool = False) -> exp.Expression: - self._match_text_seq("TAG") - parser = t.cast(t.Callable, self._parse_id_var if unset else self._parse_conjunction) - return self.expression(exp.SetTag, expressions=self._parse_csv(parser), unset=unset) + def _parse_id_var( + self, + any_token: bool = True, + tokens: t.Optional[t.Collection[TokenType]] = None, + ) -> t.Optional[exp.Expression]: + if self._match_text_seq("IDENTIFIER", "("): + identifier = ( + super()._parse_id_var(any_token=any_token, tokens=tokens) + or self._parse_string() + ) + self._match_r_paren() + return self.expression(exp.Anonymous, this="IDENTIFIER", expressions=[identifier]) + + return super()._parse_id_var(any_token=any_token, tokens=tokens) class Tokenizer(tokens.Tokenizer): QUOTES = ["'", "$$"] @@ -380,10 +395,6 @@ class Snowflake(Dialect): self.unsupported("INTERSECT with All is not supported in Snowflake") return super().intersect_op(expression) - def settag_sql(self, expression: exp.SetTag) -> str: - action = "UNSET" if expression.args.get("unset") else "SET" - return f"{action} TAG {self.expressions(expression)}" - def describe_sql(self, expression: exp.Describe) -> str: # Default to table if kind is unknown kind_value = expression.args.get("kind") or "TABLE" diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index b7d1641..7a7ee01 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -43,6 +43,7 @@ class Spark(Spark2): class Generator(Spark2.Generator): TRANSFORMS = Spark2.Generator.TRANSFORMS.copy() TRANSFORMS.pop(exp.DateDiff) + TRANSFORMS.pop(exp.Group) def datediff_sql(self, expression: exp.DateDiff) -> str: unit = self.sql(expression, "unit") diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index ed6992d..3720b8d 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -231,14 +231,14 @@ class Spark2(Hive): WRAP_DERIVED_VALUES = False CREATE_FUNCTION_RETURN_AS = False - def cast_sql(self, expression: exp.Cast) -> str: + def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: if isinstance(expression.this, exp.Cast) and expression.this.is_type("json"): schema = f"'{self.sql(expression, 'to')}'" return self.func("FROM_JSON", expression.this.this, schema) if expression.is_type("json"): return self.func("TO_JSON", expression.this) - return super(Hive.Generator, self).cast_sql(expression) + return super(Hive.Generator, self).cast_sql(expression, safe_prefix=safe_prefix) def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: return super().columndef_sql( diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 803f361..519e62a 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, @@ -133,7 +135,7 @@ class SQLite(Dialect): LIMIT_FETCH = "LIMIT" - def cast_sql(self, expression: exp.Cast) -> str: + def cast_sql(self, expression: exp.Cast, 
safe_prefix: t.Optional[str] = None) -> str: if expression.is_type("date"): return self.func("DATE", expression.this) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 6d674f5..f671630 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -166,6 +166,7 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s class TSQL(Dialect): + RESOLVES_IDENTIFIERS_AS_UPPERCASE = None NULL_ORDERING = "nulls_are_small" TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'" diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py index c405c45..630cb65 100644 --- a/sqlglot/executor/context.py +++ b/sqlglot/executor/context.py @@ -63,11 +63,9 @@ class Context: reader = table[i] yield reader, self - def table_iter(self, table: str) -> t.Iterator[t.Tuple[TableIter, Context]]: + def table_iter(self, table: str) -> TableIter: self.env["scope"] = self.row_readers - - for reader in self.tables[table]: - yield reader, self + return iter(self.tables[table]) def filter(self, condition) -> None: rows = [reader.row for reader, _ in self if self.eval(condition)] diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index 635ec2c..a927181 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -276,11 +276,9 @@ class PythonExecutor: end = 1 length = len(context.table) table = self.table(list(step.group) + step.aggregations) - condition = self.generate(step.condition) def add_row(): - if not condition or context.eval(condition): - table.append(group + context.eval_tuple(aggregations)) + table.append(group + context.eval_tuple(aggregations)) if length: for i in range(length): @@ -304,7 +302,7 @@ class PythonExecutor: context = self.context({step.name: table, **{name: table for name in context.tables}}) - if step.projections: + if step.projections or step.condition: return self.scan(step, context) return context diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 1c0af58..e01cc1a 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1013,7 +1013,7 @@ class Pragma(Expression): class Set(Expression): - arg_types = {"expressions": False} + arg_types = {"expressions": False, "unset": False, "tag": False} class SetItem(Expression): @@ -1168,10 +1168,6 @@ class RenameTable(Expression): pass -class SetTag(Expression): - arg_types = {"expressions": True, "unset": False} - - class Comment(Expression): arg_types = {"this": True, "kind": True, "expression": True, "exists": False} @@ -1934,6 +1930,11 @@ class LanguageProperty(Property): arg_types = {"this": True} +# spark ddl +class ClusteredByProperty(Property): + arg_types = {"expressions": True, "sorted_by": False, "buckets": True} + + class DictProperty(Property): arg_types = {"this": True, "kind": True, "settings": False} @@ -2074,6 +2075,7 @@ class Properties(Expression): "ALGORITHM": AlgorithmProperty, "AUTO_INCREMENT": AutoIncrementProperty, "CHARACTER SET": CharacterSetProperty, + "CLUSTERED_BY": ClusteredByProperty, "COLLATE": CollateProperty, "COMMENT": SchemaCommentProperty, "DEFINER": DefinerProperty, @@ -2280,6 +2282,12 @@ class Table(Expression): "system_time": False, } + @property + def name(self) -> str: + if isinstance(self.this, Func): + return "" + return self.this.name + @property def db(self) -> str: return self.text("db") @@ -3716,6 +3724,10 @@ class Bracket(Condition): arg_types = {"this": True, "expressions": True} +class SafeBracket(Bracket): + """Represents array lookup where OOB index yields NULL instead of causing a failure.""" + 
+ class Distinct(Expression): arg_types = {"expressions": False, "on": False} @@ -3934,7 +3946,7 @@ class Case(Func): class Cast(Func): - arg_types = {"this": True, "to": True} + arg_types = {"this": True, "to": True, "format": False} @property def name(self) -> str: @@ -4292,6 +4304,10 @@ class Map(Func): arg_types = {"keys": False, "values": False} +class MapFromEntries(Func): + pass + + class StarMap(Func): pass diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 81e0ac3..5d8a4ca 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -188,6 +188,7 @@ class Generator: exp.CollateProperty: exp.Properties.Location.POST_SCHEMA, exp.CopyGrantsProperty: exp.Properties.Location.POST_SCHEMA, exp.Cluster: exp.Properties.Location.POST_SCHEMA, + exp.ClusteredByProperty: exp.Properties.Location.POST_SCHEMA, exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME, exp.DefinerProperty: exp.Properties.Location.POST_CREATE, exp.DictRange: exp.Properties.Location.POST_SCHEMA, @@ -1408,7 +1409,8 @@ class Generator: expressions = ( f" {self.expressions(expression, flat=True)}" if expression.expressions else "" ) - return f"SET{expressions}" + tag = " TAG" if expression.args.get("tag") else "" + return f"{'UNSET' if expression.args.get('unset') else 'SET'}{tag}{expressions}" def pragma_sql(self, expression: exp.Pragma) -> str: return f"PRAGMA {self.sql(expression, 'this')}" @@ -1749,6 +1751,9 @@ class Generator: return f"{self.sql(expression, 'this')}[{expressions_sql}]" + def safebracket_sql(self, expression: exp.SafeBracket) -> str: + return self.bracket_sql(expression) + def all_sql(self, expression: exp.All) -> str: return f"ALL {self.wrap(expression)}" @@ -2000,8 +2005,10 @@ class Generator: def bitwisexor_sql(self, expression: exp.BitwiseXor) -> str: return self.binary(expression, "^") - def cast_sql(self, expression: exp.Cast) -> str: - return f"CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" + def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: + format_sql = self.sql(expression, "format") + format_sql = f" FORMAT {format_sql}" if format_sql else "" + return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')}{format_sql})" def currentdate_sql(self, expression: exp.CurrentDate) -> str: zone = self.sql(expression, "this") @@ -2227,7 +2234,7 @@ class Generator: return self.binary(expression, "-") def trycast_sql(self, expression: exp.TryCast) -> str: - return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" + return self.cast_sql(expression, safe_prefix="TRY_") def use_sql(self, expression: exp.Use) -> str: kind = self.sql(expression, "kind") @@ -2409,6 +2416,13 @@ class Generator: def oncluster_sql(self, expression: exp.OnCluster) -> str: return "" + def clusteredbyproperty_sql(self, expression: exp.ClusteredByProperty) -> str: + expressions = self.expressions(expression, key="expressions", flat=True) + sorted_by = self.expressions(expression, key="sorted_by", flat=True) + sorted_by = f" SORTED BY ({sorted_by})" if sorted_by else "" + buckets = self.sql(expression, "buckets") + return f"CLUSTERED BY ({expressions}){sorted_by} INTO {buckets} BUCKETS" + def cached_generator( cache: t.Optional[t.Dict[int, str]] = None diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py index 5fdbde8..6e15c6a 100644 --- a/sqlglot/optimizer/qualify.py +++ b/sqlglot/optimizer/qualify.py @@ -60,8 +60,8 @@ def qualify( The qualified expression. 
""" schema = ensure_schema(schema, dialect=dialect) - expression = normalize_identifiers(expression, dialect=dialect) expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema) + expression = normalize_identifiers(expression, dialect=dialect) if isolate_tables: expression = isolate_table_selects(expression, schema=schema) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index ac8eb0f..ef8aeb1 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -56,13 +56,13 @@ def qualify_columns( if not isinstance(scope.expression, exp.UDTF): _expand_stars(scope, resolver, using_column_tables) _qualify_outputs(scope) - _expand_group_by(scope, resolver) - _expand_order_by(scope) + _expand_group_by(scope) + _expand_order_by(scope, resolver) return expression -def validate_qualify_columns(expression): +def validate_qualify_columns(expression: E) -> E: """Raise an `OptimizeError` if any columns aren't qualified""" unqualified_columns = [] for scope in traverse_scope(expression): @@ -79,7 +79,7 @@ def validate_qualify_columns(expression): return expression -def _pop_table_column_aliases(derived_tables): +def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None: """ Remove table column aliases. @@ -91,13 +91,13 @@ def _pop_table_column_aliases(derived_tables): table_alias.args.pop("columns", None) -def _expand_using(scope, resolver): +def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]: joins = list(scope.find_all(exp.Join)) names = {join.alias_or_name for join in joins} ordered = [key for key in scope.selected_sources if key not in names] # Mapping of automatically joined column names to an ordered set of source names (dict). 
- column_tables = {} + column_tables: t.Dict[str, t.Dict[str, t.Any]] = {} for join in joins: using = join.args.get("using") @@ -172,20 +172,25 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: alias_to_expression: t.Dict[str, exp.Expression] = {} - def replace_columns( - node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False - ): + def replace_columns(node: t.Optional[exp.Expression], resolve_table: bool = False) -> None: if not node: return for column, *_ in walk_in_scope(node): if not isinstance(column, exp.Column): continue - table = resolver.get_table(column.name) if resolve_agg and not column.table else None - if table and column.find_ancestor(exp.AggFunc): + table = resolver.get_table(column.name) if resolve_table and not column.table else None + alias_expr = alias_to_expression.get(column.name) + double_agg = ( + (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc)) + if alias_expr + else False + ) + + if table and (not alias_expr or double_agg): column.set("table", table) - elif expand and not column.table and column.name in alias_to_expression: - column.replace(alias_to_expression[column.name].copy()) + elif not column.table and alias_expr and not double_agg: + column.replace(alias_expr.copy()) for projection in scope.selects: replace_columns(projection) @@ -195,22 +200,41 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: replace_columns(expression.args.get("where")) replace_columns(expression.args.get("group")) - replace_columns(expression.args.get("having"), resolve_agg=True) - replace_columns(expression.args.get("qualify"), resolve_agg=True) - replace_columns(expression.args.get("order"), expand=False, resolve_agg=True) + replace_columns(expression.args.get("having"), resolve_table=True) + replace_columns(expression.args.get("qualify"), resolve_table=True) scope.clear_cache() -def _expand_group_by(scope, resolver): - group = scope.expression.args.get("group") +def _expand_group_by(scope: Scope): + expression = scope.expression + group = expression.args.get("group") if not group: return group.set("expressions", _expand_positional_references(scope, group.expressions)) - scope.expression.set("group", group) + expression.set("group", group) + + # group by expressions cannot be simplified, for example + # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 + # the projection must exactly match the group by key + groups = set(group.expressions) + group.meta["final"] = True + + for e in expression.selects: + for node, *_ in e.walk(): + if node in groups: + e.meta["final"] = True + break + having = expression.args.get("having") + if having: + for node, *_ in having.walk(): + if node in groups: + having.meta["final"] = True + break -def _expand_order_by(scope): + +def _expand_order_by(scope: Scope, resolver: Resolver): order = scope.expression.args.get("order") if not order: return @@ -220,10 +244,21 @@ def _expand_order_by(scope): ordereds, _expand_positional_references(scope, (o.this for o in ordereds)), ): + for agg in ordered.find_all(exp.AggFunc): + for col in agg.find_all(exp.Column): + if not col.table: + col.set("table", resolver.get_table(col.name)) + ordered.set("this", new_expression) + if scope.expression.args.get("group"): + selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects} + + for ordered in ordereds: + ordered.set("this", selects.get(ordered.this, ordered.this)) -def _expand_positional_references(scope, expressions): + +def _expand_positional_references(scope: Scope, expressions: 
t.Iterable[E]) -> t.List[E]: new_nodes = [] for node in expressions: if node.is_int: @@ -241,7 +276,7 @@ def _expand_positional_references(scope, expressions): return new_nodes -def _qualify_columns(scope, resolver): +def _qualify_columns(scope: Scope, resolver: Resolver) -> None: """Disambiguate columns, ensuring each column specifies a source""" for column in scope.columns: column_table = column.table @@ -290,21 +325,23 @@ def _qualify_columns(scope, resolver): column.set("table", column_table) -def _expand_stars(scope, resolver, using_column_tables): +def _expand_stars( + scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any] +) -> None: """Expand stars to lists of column selections""" new_selections = [] - except_columns = {} - replace_columns = {} + except_columns: t.Dict[int, t.Set[str]] = {} + replace_columns: t.Dict[int, t.Dict[str, str]] = {} coalesced_columns = set() # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future pivot_columns = None pivot_output_columns = None - pivot = seq_get(scope.pivots, 0) + pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0)) has_pivoted_source = pivot and not pivot.args.get("unpivot") - if has_pivoted_source: + if pivot and has_pivoted_source: pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column)) pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])] @@ -330,8 +367,17 @@ def _expand_stars(scope, resolver, using_column_tables): columns = resolver.get_source_columns(table, only_visible=True) + # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement + # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table + if resolver.schema.dialect == "bigquery": + columns = [ + name + for name in columns + if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE") + ] + if columns and "*" not in columns: - if has_pivoted_source: + if pivot and has_pivoted_source and pivot_columns and pivot_output_columns: implicit_columns = [col for col in columns if col not in pivot_columns] new_selections.extend( exp.alias_(exp.column(name, table=pivot.alias), name, copy=False) @@ -368,7 +414,9 @@ def _expand_stars(scope, resolver, using_column_tables): scope.expression.set("expressions", new_selections) -def _add_except_columns(expression, tables, except_columns): +def _add_except_columns( + expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]] +) -> None: except_ = expression.args.get("except") if not except_: @@ -380,7 +428,9 @@ def _add_except_columns(expression, tables, except_columns): except_columns[id(table)] = columns -def _add_replace_columns(expression, tables, replace_columns): +def _add_replace_columns( + expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]] +) -> None: replace = expression.args.get("replace") if not replace: @@ -392,7 +442,7 @@ def _add_replace_columns(expression, tables, replace_columns): replace_columns[id(table)] = columns -def _qualify_outputs(scope): +def _qualify_outputs(scope: Scope): """Ensure all output columns are aliased""" new_selections = [] @@ -429,7 +479,7 @@ class Resolver: This is a class so we can lazily load some things and easily share them across functions. 
""" - def __init__(self, scope, schema, infer_schema: bool = True): + def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): self.scope = scope self.schema = schema self._source_columns = None diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 5365aef..34005d9 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -28,6 +28,8 @@ def simplify(expression): generate = cached_generator() def _simplify(expression, root=True): + if expression.meta.get("final"): + return expression node = expression node = rewrite_between(node) node = uniq_sort(node, generate, root) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index e16a88e..e5bd4ae 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -585,6 +585,7 @@ class Parser(metaclass=_Parser): "CHARACTER SET": lambda self: self._parse_character_set(), "CHECKSUM": lambda self: self._parse_checksum(), "CLUSTER BY": lambda self: self._parse_cluster(), + "CLUSTERED": lambda self: self._parse_clustered_by(), "COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty), "COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), "COPY": lambda self: self._parse_copy_property(), @@ -794,8 +795,6 @@ class Parser(metaclass=_Parser): # A NULL arg in CONCAT yields NULL by default CONCAT_NULL_OUTPUTS_STRING = False - CONVERT_TYPE_FIRST = False - PREFIXED_PIVOT_COLUMNS = False IDENTIFY_PIVOT_STRINGS = False @@ -1426,9 +1425,34 @@ class Parser(metaclass=_Parser): return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT)) - def _parse_cluster(self) -> t.Optional[exp.Cluster]: + def _parse_cluster(self) -> exp.Cluster: return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered)) + def _parse_clustered_by(self) -> exp.ClusteredByProperty: + self._match_text_seq("BY") + + self._match_l_paren() + expressions = self._parse_csv(self._parse_column) + self._match_r_paren() + + if self._match_text_seq("SORTED", "BY"): + self._match_l_paren() + sorted_by = self._parse_csv(self._parse_ordered) + self._match_r_paren() + else: + sorted_by = None + + self._match(TokenType.INTO) + buckets = self._parse_number() + self._match_text_seq("BUCKETS") + + return self.expression( + exp.ClusteredByProperty, + expressions=expressions, + sorted_by=sorted_by, + buckets=buckets, + ) + def _parse_copy_property(self) -> t.Optional[exp.CopyGrantsProperty]: if not self._match_text_seq("GRANTS"): self._retreat(self._index - 1) @@ -2863,7 +2887,11 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.INTERVAL): return None - this = self._parse_primary() or self._parse_term() + if self._match(TokenType.STRING, advance=False): + this = self._parse_primary() + else: + this = self._parse_term() + unit = self._parse_function() or self._parse_var() # Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse @@ -3661,6 +3689,7 @@ class Parser(metaclass=_Parser): else: self.raise_error("Expected AS after CAST") + fmt = None to = self._parse_types() if not to: @@ -3668,22 +3697,23 @@ class Parser(metaclass=_Parser): elif to.this == exp.DataType.Type.CHAR: if self._match(TokenType.CHARACTER_SET): to = self.expression(exp.CharacterSet, this=self._parse_var_or_string()) - elif to.this in exp.DataType.TEMPORAL_TYPES and self._match(TokenType.FORMAT): - fmt = self._parse_string() + elif self._match(TokenType.FORMAT): + fmt = self._parse_at_time_zone(self._parse_string()) - return self.expression( - 
exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime, - this=this, - format=exp.Literal.string( - format_time( - fmt.this if fmt else "", - self.FORMAT_MAPPING or self.TIME_MAPPING, - self.FORMAT_TRIE or self.TIME_TRIE, - ) - ), - ) + if to.this in exp.DataType.TEMPORAL_TYPES: + return self.expression( + exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime, + this=this, + format=exp.Literal.string( + format_time( + fmt.this if fmt else "", + self.FORMAT_MAPPING or self.TIME_MAPPING, + self.FORMAT_TRIE or self.TIME_TRIE, + ) + ), + ) - return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) + return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt) def _parse_concat(self) -> t.Optional[exp.Expression]: args = self._parse_csv(self._parse_conjunction) @@ -3704,20 +3734,23 @@ class Parser(metaclass=_Parser): ) def _parse_string_agg(self) -> exp.Expression: - expression: t.Optional[exp.Expression] - if self._match(TokenType.DISTINCT): - args = self._parse_csv(self._parse_conjunction) - expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)]) + args: t.List[t.Optional[exp.Expression]] = [ + self.expression(exp.Distinct, expressions=[self._parse_conjunction()]) + ] + if self._match(TokenType.COMMA): + args.extend(self._parse_csv(self._parse_conjunction)) else: args = self._parse_csv(self._parse_conjunction) - expression = seq_get(args, 0) index = self._index if not self._match(TokenType.R_PAREN): # postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]]) - order = self._parse_order(this=expression) - return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1)) + return self.expression( + exp.GroupConcat, + this=seq_get(args, 0), + separator=self._parse_order(this=seq_get(args, 1)), + ) # Checks if we can parse an order clause: WITHIN GROUP (ORDER BY [ASC | DESC]). 
# This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that @@ -3727,24 +3760,21 @@ class Parser(metaclass=_Parser): return self.validate_expression(exp.GroupConcat.from_arg_list(args), args) self._match_l_paren() # The corresponding match_r_paren will be called in parse_function (caller) - order = self._parse_order(this=expression) + order = self._parse_order(this=seq_get(args, 0)) return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1)) def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]: - to: t.Optional[exp.Expression] this = self._parse_bitwise() if self._match(TokenType.USING): - to = self.expression(exp.CharacterSet, this=self._parse_var()) + to: t.Optional[exp.Expression] = self.expression( + exp.CharacterSet, this=self._parse_var() + ) elif self._match(TokenType.COMMA): - to = self._parse_bitwise() + to = self._parse_types() else: to = None - # Swap the argument order if needed to produce the correct AST - if self.CONVERT_TYPE_FIRST: - this, to = to, this - return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) def _parse_decode(self) -> t.Optional[exp.Decode | exp.Case]: @@ -4394,8 +4424,8 @@ class Parser(metaclass=_Parser): if self._next: self._advance() - parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None + parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None if parser: actions = ensure_list(parser(self)) @@ -4516,9 +4546,11 @@ class Parser(metaclass=_Parser): parser = self._find_parser(self.SET_PARSERS, self.SET_TRIE) return parser(self) if parser else self._parse_set_item_assignment(kind=None) - def _parse_set(self) -> exp.Set | exp.Command: + def _parse_set(self, unset: bool = False, tag: bool = False) -> exp.Set | exp.Command: index = self._index - set_ = self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item)) + set_ = self.expression( + exp.Set, expressions=self._parse_csv(self._parse_set_item), unset=unset, tag=tag + ) if self._curr: self._retreat(index) @@ -4683,12 +4715,8 @@ class Parser(metaclass=_Parser): exp.replace_children(this, self._replace_columns_with_dots) table = this.args.get("table") this = ( - self.expression(exp.Dot, this=table, expression=this.this) - if table - else self.expression(exp.Var, this=this.name) + self.expression(exp.Dot, this=table, expression=this.this) if table else this.this ) - elif isinstance(this, exp.Identifier): - this = self.expression(exp.Var, this=this.name) return this diff --git a/sqlglot/planner.py b/sqlglot/planner.py index 4ed7449..f246702 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -91,6 +91,7 @@ class Step: A Step DAG corresponding to `expression`. """ ctes = ctes or {} + expression = expression.unnest() with_ = expression.args.get("with") # CTEs break the mold of scope and introduce themselves to all in the context. 
@@ -120,22 +121,25 @@ class Step: projections = [] # final selects in this chain of steps representing a select operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1) - aggregations = [] + aggregations = set() next_operand_name = name_sequence("_a_") def extract_agg_operands(expression): - for agg in expression.find_all(exp.AggFunc): + agg_funcs = tuple(expression.find_all(exp.AggFunc)) + if agg_funcs: + aggregations.add(expression) + for agg in agg_funcs: for operand in agg.unnest_operands(): if isinstance(operand, exp.Column): continue if operand not in operands: operands[operand] = next_operand_name() operand.replace(exp.column(operands[operand], quoted=True)) + return bool(agg_funcs) for e in expression.expressions: if e.find(exp.AggFunc): projections.append(exp.column(e.alias_or_name, step.name, quoted=True)) - aggregations.append(e) extract_agg_operands(e) else: projections.append(e) @@ -155,22 +159,38 @@ class Step: having = expression.args.get("having") if having: - extract_agg_operands(having) - aggregate.condition = having.this + if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)): + aggregate.condition = exp.column("_h", step.name, quoted=True) + else: + aggregate.condition = having.this aggregate.operands = tuple( alias(operand, alias_) for operand, alias_ in operands.items() ) - aggregate.aggregations = aggregations + aggregate.aggregations = list(aggregations) + # give aggregates names and replace projections with references to them aggregate.group = { f"_g{i}": e for i, e in enumerate(group.expressions if group else []) } + + intermediate: t.Dict[str | exp.Expression, str] = {} + for k, v in aggregate.group.items(): + intermediate[v] = k + if isinstance(v, exp.Column): + intermediate[v.alias_or_name] = k + for projection in projections: - for i, e in aggregate.group.items(): - for child, *_ in projection.walk(): - if child == e: - child.replace(exp.column(i, step.name)) + for node, *_ in projection.walk(): + name = intermediate.get(node) + if name: + node.replace(exp.column(name, step.name)) + if aggregate.condition: + for node, *_ in aggregate.condition.walk(): + name = intermediate.get(node) or intermediate.get(node.name) + if name: + node.replace(exp.column(name, step.name)) + aggregate.add_dependency(step) step = aggregate diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index ba72616..1f30f96 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -159,10 +159,11 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression: if isinstance(expression, exp.Select): from sqlglot.optimizer.scope import build_scope - taken_select_names = set(expression.named_selects) scope = build_scope(expression) if not scope: return expression + + taken_select_names = set(expression.named_selects) taken_source_names = set(scope.selected_sources) for select in expression.selects: -- cgit v1.2.3
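To illustrate the new _pushdown_cte_column_names transform: BigQuery rejects column lists on CTE definitions, so the column names are pushed down onto the inner projections as aliases. A minimal sketch of the expected rewrite (output shape inferred from the transform, not quoted from the patch's tests):

    import sqlglot

    # t(a, b) is not valid BigQuery; the column names move onto the projections.
    sql = "WITH t(a, b) AS (SELECT 1, 2) SELECT a FROM t"
    print(sqlglot.transpile(sql, write="bigquery")[0])
    # expected: WITH t AS (SELECT 1 AS a, 2 AS b) SELECT a FROM t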
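The _alias_ordered_group transform works around the BigQuery grouping bug linked in the patch by substituting a grouped expression with its projection alias whenever the query also has an ORDER BY. Roughly (assumed output):

    import sqlglot

    sql = "SELECT y + 1 AS z FROM x GROUP BY y + 1 ORDER BY z"
    print(sqlglot.transpile(sql, write="bigquery")[0])
    # expected: SELECT y + 1 AS z FROM x GROUP BY z ORDER BY z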
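TryCast generation now funnels through cast_sql via the new safe_prefix argument, so the base generator renders TRY_CAST and BigQuery renders SAFE_CAST from one code path. A sketch (the TEXT-to-STRING type mapping is assumed):

    import sqlglot

    print(sqlglot.transpile("SELECT TRY_CAST(x AS TEXT)", write="bigquery")[0])
    # expected: SELECT SAFE_CAST(x AS STRING)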
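MySQL's cast_sql override emits the SIGNED/UNSIGNED forms its CAST actually accepts in place of BIGINT/UBIGINT, mirroring the new tokenizer keywords. Both directions in one sketch:

    import sqlglot

    # BIGINT is not a legal CAST target in MySQL; SIGNED is emitted instead.
    print(sqlglot.transpile("SELECT CAST(x AS BIGINT)", write="mysql")[0])
    # expected: SELECT CAST(x AS SIGNED)

    # UNSIGNED now tokenizes to UBIGINT on the way in and roundtrips cleanly.
    print(sqlglot.transpile("SELECT CAST(x AS UNSIGNED)", read="mysql", write="mysql")[0])
    # expected: SELECT CAST(x AS UNSIGNED)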
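Postgres disallows subscripting an array constructor directly, so the new bracket_sql wraps it in parentheses. Building the node by hand shows the effect (a sketch; expected output assumed from the override):

    from sqlglot import exp

    node = exp.Bracket(
        this=exp.Array(expressions=[exp.Literal.number(n) for n in (1, 2, 3)]),
        expressions=[exp.Literal.number(3)],
    )
    print(node.sql(dialect="postgres"))
    # expected: (ARRAY[1, 2, 3])[3]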
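Presto's ELEMENT_AT now parses into the new exp.SafeBracket node, with apply_index_offset translating Presto's 1-based subscripts to sqlglot's 0-based internal convention and back on generation. A roundtrip sketch:

    import sqlglot

    # The index is stored 0-based internally and shifted back to 1-based on output.
    print(sqlglot.transpile("SELECT ELEMENT_AT(arr, 1)", read="presto", write="presto")[0])
    # expected: SELECT ELEMENT_AT(arr, 1)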
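Snowflake's _parse_id_var override keeps IDENTIFIER(...) intact wherever an identifier may appear, representing it as an anonymous function call rather than resolving it. A sketch:

    import sqlglot

    sql = "SELECT * FROM IDENTIFIER('my_table')"
    print(sqlglot.transpile(sql, read="snowflake", write="snowflake")[0])
    # expected: SELECT * FROM IDENTIFIER('my_table')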
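Together, _parse_clustered_by and clusteredbyproperty_sql cover Hive/Spark bucketing DDL. A roundtrip sketch (property placement in the generated CREATE statement is assumed):

    import sqlglot

    sql = "CREATE TABLE t (a INT, b STRING) CLUSTERED BY (a) SORTED BY (b) INTO 16 BUCKETS"
    print(sqlglot.transpile(sql, read="spark", write="spark")[0])
    # expected to roundtrip unchanged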