summaryrefslogtreecommitdiffstats
path: root/sqlglot
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot')
-rw-r--r--sqlglot/dataframe/sql/functions.py2
-rw-r--r--sqlglot/dialects/bigquery.py205
-rw-r--r--sqlglot/dialects/dialect.py5
-rw-r--r--sqlglot/dialects/mysql.py15
-rw-r--r--sqlglot/dialects/postgres.py10
-rw-r--r--sqlglot/dialects/presto.py13
-rw-r--r--sqlglot/dialects/redshift.py8
-rw-r--r--sqlglot/dialects/snowflake.py31
-rw-r--r--sqlglot/dialects/spark.py1
-rw-r--r--sqlglot/dialects/spark2.py4
-rw-r--r--sqlglot/dialects/sqlite.py4
-rw-r--r--sqlglot/dialects/tsql.py1
-rw-r--r--sqlglot/executor/context.py6
-rw-r--r--sqlglot/executor/python.py6
-rw-r--r--sqlglot/expressions.py28
-rw-r--r--sqlglot/generator.py22
-rw-r--r--sqlglot/optimizer/qualify.py2
-rw-r--r--sqlglot/optimizer/qualify_columns.py114
-rw-r--r--sqlglot/optimizer/simplify.py2
-rw-r--r--sqlglot/parser.py110
-rw-r--r--sqlglot/planner.py40
-rw-r--r--sqlglot/transforms.py3
22 files changed, 500 insertions, 132 deletions
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 71385aa..bdc1fb4 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -1119,7 +1119,7 @@ def map_entries(col: ColumnOrName) -> Column:
def map_from_entries(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "MAP_FROM_ENTRIES")
+ return Column.invoke_expression_over_column(col, expression.MapFromEntries)
def array_repeat(col: ColumnOrName, count: t.Union[ColumnOrName, int]) -> Column:
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 52d4a88..8786063 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import logging
import re
import typing as t
@@ -21,6 +22,8 @@ from sqlglot.dialects.dialect import (
from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType
+logger = logging.getLogger("sqlglot")
+
def _date_add_sql(
data_type: str, kind: str
@@ -104,12 +107,70 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
return expression
+# https://issuetracker.google.com/issues/162294746
+# workaround for bigquery bug when grouping by an expression and then ordering
+# WITH x AS (SELECT 1 y)
+# SELECT y + 1 z
+# FROM x
+# GROUP BY x + 1
+# ORDER by z
+def _alias_ordered_group(expression: exp.Expression) -> exp.Expression:
+ if isinstance(expression, exp.Select):
+ group = expression.args.get("group")
+ order = expression.args.get("order")
+
+ if group and order:
+ aliases = {
+ select.this: select.args["alias"]
+ for select in expression.selects
+ if isinstance(select, exp.Alias)
+ }
+
+ for e in group.expressions:
+ alias = aliases.get(e)
+
+ if alias:
+ e.replace(exp.column(alias))
+
+ return expression
+
+
+def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression:
+ """BigQuery doesn't allow column names when defining a CTE, so we try to push them down."""
+ if isinstance(expression, exp.CTE) and expression.alias_column_names:
+ cte_query = expression.this
+
+ if cte_query.is_star:
+ logger.warning(
+ "Can't push down CTE column names for star queries. Run the query through"
+ " the optimizer or use 'qualify' to expand the star projections first."
+ )
+ return expression
+
+ column_names = expression.alias_column_names
+ expression.args["alias"].set("columns", None)
+
+ for name, select in zip(column_names, cte_query.selects):
+ to_replace = select
+
+ if isinstance(select, exp.Alias):
+ select = select.this
+
+ # Inner aliases are shadowed by the CTE column names
+ to_replace.replace(exp.alias_(select, name))
+
+ return expression
+
+
class BigQuery(Dialect):
UNNEST_COLUMN_ONLY = True
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+ # bigquery udfs are case sensitive
+ NORMALIZE_FUNCTIONS = False
+
TIME_MAPPING = {
"%D": "%m/%d/%y",
}
@@ -135,12 +196,16 @@ class BigQuery(Dialect):
# In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).
# The following check is essentially a heuristic to detect tables based on whether or
# not they're qualified.
- if (
- isinstance(expression, exp.Identifier)
- and not (isinstance(expression.parent, exp.Table) and expression.parent.db)
- and not expression.meta.get("is_table")
- ):
- expression.set("this", expression.this.lower())
+ if isinstance(expression, exp.Identifier):
+ parent = expression.parent
+
+ while isinstance(parent, exp.Dot):
+ parent = parent.parent
+
+ if not (isinstance(parent, exp.Table) and parent.db) and not expression.meta.get(
+ "is_table"
+ ):
+ expression.set("this", expression.this.lower())
return expression
@@ -298,10 +363,8 @@ class BigQuery(Dialect):
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
- exp.AtTimeZone: lambda self, e: self.func(
- "TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
- ),
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
+ exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
@@ -325,7 +388,12 @@ class BigQuery(Dialect):
),
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
exp.Select: transforms.preprocess(
- [_unqualify_unnest, transforms.eliminate_distinct_on]
+ [
+ transforms.explode_to_unnest,
+ _unqualify_unnest,
+ transforms.eliminate_distinct_on,
+ _alias_ordered_group,
+ ]
),
exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
@@ -334,7 +402,6 @@ class BigQuery(Dialect):
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
- exp.TryCast: lambda self, e: f"SAFE_CAST({self.sql(e, 'this')} AS {self.sql(e, 'to')})",
exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
@@ -378,7 +445,121 @@ class BigQuery(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
- RESERVED_KEYWORDS = {*generator.Generator.RESERVED_KEYWORDS, "hash"}
+ # from: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords
+ RESERVED_KEYWORDS = {
+ *generator.Generator.RESERVED_KEYWORDS,
+ "all",
+ "and",
+ "any",
+ "array",
+ "as",
+ "asc",
+ "assert_rows_modified",
+ "at",
+ "between",
+ "by",
+ "case",
+ "cast",
+ "collate",
+ "contains",
+ "create",
+ "cross",
+ "cube",
+ "current",
+ "default",
+ "define",
+ "desc",
+ "distinct",
+ "else",
+ "end",
+ "enum",
+ "escape",
+ "except",
+ "exclude",
+ "exists",
+ "extract",
+ "false",
+ "fetch",
+ "following",
+ "for",
+ "from",
+ "full",
+ "group",
+ "grouping",
+ "groups",
+ "hash",
+ "having",
+ "if",
+ "ignore",
+ "in",
+ "inner",
+ "intersect",
+ "interval",
+ "into",
+ "is",
+ "join",
+ "lateral",
+ "left",
+ "like",
+ "limit",
+ "lookup",
+ "merge",
+ "natural",
+ "new",
+ "no",
+ "not",
+ "null",
+ "nulls",
+ "of",
+ "on",
+ "or",
+ "order",
+ "outer",
+ "over",
+ "partition",
+ "preceding",
+ "proto",
+ "qualify",
+ "range",
+ "recursive",
+ "respect",
+ "right",
+ "rollup",
+ "rows",
+ "select",
+ "set",
+ "some",
+ "struct",
+ "tablesample",
+ "then",
+ "to",
+ "treat",
+ "true",
+ "unbounded",
+ "union",
+ "unnest",
+ "using",
+ "when",
+ "where",
+ "window",
+ "with",
+ "within",
+ }
+
+ def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
+ if not isinstance(expression.parent, exp.Cast):
+ return self.func(
+ "TIMESTAMP", self.func("DATETIME", expression.this, expression.args.get("zone"))
+ )
+ return super().attimezone_sql(expression)
+
+ def trycast_sql(self, expression: exp.TryCast) -> str:
+ return self.cast_sql(expression, safe_prefix="SAFE_")
+
+ def cte_sql(self, expression: exp.CTE) -> str:
+ if expression.alias_column_names:
+ self.unsupported("Column names in CTE definition are not supported.")
+ return super().cte_sql(expression)
def array_sql(self, expression: exp.Array) -> str:
first_arg = seq_get(expression.expressions, 0)
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 0e25b9b..d258826 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -388,6 +388,11 @@ def no_comment_column_constraint_sql(
return ""
+def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
+ self.unsupported("MAP_FROM_ENTRIES unsupported")
+ return ""
+
+
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
this = self.sql(expression, "this")
substr = self.sql(expression, "substr")
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 1dd2096..5f743ee 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -132,6 +132,10 @@ class MySQL(Dialect):
"SEPARATOR": TokenType.SEPARATOR,
"ENUM": TokenType.ENUM,
"START": TokenType.BEGIN,
+ "SIGNED": TokenType.BIGINT,
+ "SIGNED INTEGER": TokenType.BIGINT,
+ "UNSIGNED": TokenType.UBIGINT,
+ "UNSIGNED INTEGER": TokenType.UBIGINT,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
"_BIG5": TokenType.INTRODUCER,
@@ -441,6 +445,17 @@ class MySQL(Dialect):
LIMIT_FETCH = "LIMIT"
+ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
+ """(U)BIGINT is not allowed in a CAST expression, so we use (UN)SIGNED instead."""
+ if expression.to.this == exp.DataType.Type.BIGINT:
+ to = "SIGNED"
+ elif expression.to.this == exp.DataType.Type.UBIGINT:
+ to = "UNSIGNED"
+ else:
+ return super().cast_sql(expression)
+
+ return f"CAST({self.sql(expression, 'this')} AS {to})"
+
def show_sql(self, expression: exp.Show) -> str:
this = f" {expression.name}"
full = " FULL" if expression.args.get("full") else ""
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 8c2a4ab..766b584 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
max_or_greatest,
min_or_least,
+ no_map_from_entries_sql,
no_paren_current_date_sql,
no_pivot_sql,
no_tablesample_sql,
@@ -346,6 +347,7 @@ class Postgres(Dialect):
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.Max: max_or_greatest,
+ exp.MapFromEntries: no_map_from_entries_sql,
exp.Min: min_or_least,
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
@@ -378,3 +380,11 @@ class Postgres(Dialect):
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+
+ def bracket_sql(self, expression: exp.Bracket) -> str:
+ """Forms like ARRAY[1, 2, 3][3] aren't allowed; we need to wrap the ARRAY."""
+ if isinstance(expression.this, exp.Array):
+ expression = expression.copy()
+ expression.set("this", exp.paren(expression.this, copy=False))
+
+ return super().bracket_sql(expression)
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 265780e..24c439b 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -20,7 +20,7 @@ from sqlglot.dialects.dialect import (
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.errors import UnsupportedError
-from sqlglot.helper import seq_get
+from sqlglot.helper import apply_index_offset, seq_get
from sqlglot.tokens import TokenType
@@ -154,6 +154,13 @@ def _from_unixtime(args: t.List) -> exp.Expression:
return exp.UnixToTime.from_arg_list(args)
+def _parse_element_at(args: t.List) -> exp.SafeBracket:
+ this = seq_get(args, 0)
+ index = seq_get(args, 1)
+ assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression)
+ return exp.SafeBracket(this=this, expressions=apply_index_offset(this, [index], -1))
+
+
def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Table):
if isinstance(expression.this, exp.GenerateSeries):
@@ -201,6 +208,7 @@ class Presto(Dialect):
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"DATE_TRUNC": date_trunc_to_time,
+ "ELEMENT_AT": _parse_element_at,
"FROM_HEX": exp.Unhex.from_arg_list,
"FROM_UNIXTIME": _from_unixtime,
"FROM_UTF8": lambda args: exp.Decode(
@@ -285,6 +293,9 @@ class Presto(Dialect):
exp.Pivot: no_pivot_sql,
exp.Quantile: _quantile_sql,
exp.Right: right_to_substring_sql,
+ exp.SafeBracket: lambda self, e: self.func(
+ "ELEMENT_AT", e.this, seq_get(apply_index_offset(e.this, e.expressions, 1), 0)
+ ),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index db6cc3f..87be42c 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -41,8 +41,6 @@ class Redshift(Postgres):
"STRTOL": exp.FromBase.from_arg_list,
}
- CONVERT_TYPE_FIRST = True
-
def _parse_types(
self, check_func: bool = False, schema: bool = False
) -> t.Optional[exp.Expression]:
@@ -58,6 +56,12 @@ class Redshift(Postgres):
return this
+ def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
+ to = self._parse_types()
+ self._match(TokenType.COMMA)
+ this = self._parse_bitwise()
+ return self.expression(exp.TryCast, this=this, to=to)
+
class Tokenizer(Postgres.Tokenizer):
BIT_STRINGS = []
HEX_STRINGS = []
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 1f620df..a2dbfd9 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -258,14 +258,29 @@ class Snowflake(Dialect):
ALTER_PARSERS = {
**parser.Parser.ALTER_PARSERS,
- "UNSET": lambda self: self._parse_alter_table_set_tag(unset=True),
- "SET": lambda self: self._parse_alter_table_set_tag(),
+ "SET": lambda self: self._parse_set(tag=self._match_text_seq("TAG")),
+ "UNSET": lambda self: self.expression(
+ exp.Set,
+ tag=self._match_text_seq("TAG"),
+ expressions=self._parse_csv(self._parse_id_var),
+ unset=True,
+ ),
}
- def _parse_alter_table_set_tag(self, unset: bool = False) -> exp.Expression:
- self._match_text_seq("TAG")
- parser = t.cast(t.Callable, self._parse_id_var if unset else self._parse_conjunction)
- return self.expression(exp.SetTag, expressions=self._parse_csv(parser), unset=unset)
+ def _parse_id_var(
+ self,
+ any_token: bool = True,
+ tokens: t.Optional[t.Collection[TokenType]] = None,
+ ) -> t.Optional[exp.Expression]:
+ if self._match_text_seq("IDENTIFIER", "("):
+ identifier = (
+ super()._parse_id_var(any_token=any_token, tokens=tokens)
+ or self._parse_string()
+ )
+ self._match_r_paren()
+ return self.expression(exp.Anonymous, this="IDENTIFIER", expressions=[identifier])
+
+ return super()._parse_id_var(any_token=any_token, tokens=tokens)
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
@@ -380,10 +395,6 @@ class Snowflake(Dialect):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)
- def settag_sql(self, expression: exp.SetTag) -> str:
- action = "UNSET" if expression.args.get("unset") else "SET"
- return f"{action} TAG {self.expressions(expression)}"
-
def describe_sql(self, expression: exp.Describe) -> str:
# Default to table if kind is unknown
kind_value = expression.args.get("kind") or "TABLE"
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index b7d1641..7a7ee01 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -43,6 +43,7 @@ class Spark(Spark2):
class Generator(Spark2.Generator):
TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
TRANSFORMS.pop(exp.DateDiff)
+ TRANSFORMS.pop(exp.Group)
def datediff_sql(self, expression: exp.DateDiff) -> str:
unit = self.sql(expression, "unit")
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index ed6992d..3720b8d 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -231,14 +231,14 @@ class Spark2(Hive):
WRAP_DERIVED_VALUES = False
CREATE_FUNCTION_RETURN_AS = False
- def cast_sql(self, expression: exp.Cast) -> str:
+ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
if isinstance(expression.this, exp.Cast) and expression.this.is_type("json"):
schema = f"'{self.sql(expression, 'to')}'"
return self.func("FROM_JSON", expression.this.this, schema)
if expression.is_type("json"):
return self.func("TO_JSON", expression.this)
- return super(Hive.Generator, self).cast_sql(expression)
+ return super(Hive.Generator, self).cast_sql(expression, safe_prefix=safe_prefix)
def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
return super().columndef_sql(
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 803f361..519e62a 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
@@ -133,7 +135,7 @@ class SQLite(Dialect):
LIMIT_FETCH = "LIMIT"
- def cast_sql(self, expression: exp.Cast) -> str:
+ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
if expression.is_type("date"):
return self.func("DATE", expression.this)
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 6d674f5..f671630 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -166,6 +166,7 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s
class TSQL(Dialect):
+ RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
NULL_ORDERING = "nulls_are_small"
TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
index c405c45..630cb65 100644
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -63,11 +63,9 @@ class Context:
reader = table[i]
yield reader, self
- def table_iter(self, table: str) -> t.Iterator[t.Tuple[TableIter, Context]]:
+ def table_iter(self, table: str) -> TableIter:
self.env["scope"] = self.row_readers
-
- for reader in self.tables[table]:
- yield reader, self
+ return iter(self.tables[table])
def filter(self, condition) -> None:
rows = [reader.row for reader, _ in self if self.eval(condition)]
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index 635ec2c..a927181 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -276,11 +276,9 @@ class PythonExecutor:
end = 1
length = len(context.table)
table = self.table(list(step.group) + step.aggregations)
- condition = self.generate(step.condition)
def add_row():
- if not condition or context.eval(condition):
- table.append(group + context.eval_tuple(aggregations))
+ table.append(group + context.eval_tuple(aggregations))
if length:
for i in range(length):
@@ -304,7 +302,7 @@ class PythonExecutor:
context = self.context({step.name: table, **{name: table for name in context.tables}})
- if step.projections:
+ if step.projections or step.condition:
return self.scan(step, context)
return context
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 1c0af58..e01cc1a 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1013,7 +1013,7 @@ class Pragma(Expression):
class Set(Expression):
- arg_types = {"expressions": False}
+ arg_types = {"expressions": False, "unset": False, "tag": False}
class SetItem(Expression):
@@ -1168,10 +1168,6 @@ class RenameTable(Expression):
pass
-class SetTag(Expression):
- arg_types = {"expressions": True, "unset": False}
-
-
class Comment(Expression):
arg_types = {"this": True, "kind": True, "expression": True, "exists": False}
@@ -1934,6 +1930,11 @@ class LanguageProperty(Property):
arg_types = {"this": True}
+# spark ddl
+class ClusteredByProperty(Property):
+ arg_types = {"expressions": True, "sorted_by": False, "buckets": True}
+
+
class DictProperty(Property):
arg_types = {"this": True, "kind": True, "settings": False}
@@ -2074,6 +2075,7 @@ class Properties(Expression):
"ALGORITHM": AlgorithmProperty,
"AUTO_INCREMENT": AutoIncrementProperty,
"CHARACTER SET": CharacterSetProperty,
+ "CLUSTERED_BY": ClusteredByProperty,
"COLLATE": CollateProperty,
"COMMENT": SchemaCommentProperty,
"DEFINER": DefinerProperty,
@@ -2281,6 +2283,12 @@ class Table(Expression):
}
@property
+ def name(self) -> str:
+ if isinstance(self.this, Func):
+ return ""
+ return self.this.name
+
+ @property
def db(self) -> str:
return self.text("db")
@@ -3716,6 +3724,10 @@ class Bracket(Condition):
arg_types = {"this": True, "expressions": True}
+class SafeBracket(Bracket):
+ """Represents array lookup where OOB index yields NULL instead of causing a failure."""
+
+
class Distinct(Expression):
arg_types = {"expressions": False, "on": False}
@@ -3934,7 +3946,7 @@ class Case(Func):
class Cast(Func):
- arg_types = {"this": True, "to": True}
+ arg_types = {"this": True, "to": True, "format": False}
@property
def name(self) -> str:
@@ -4292,6 +4304,10 @@ class Map(Func):
arg_types = {"keys": False, "values": False}
+class MapFromEntries(Func):
+ pass
+
+
class StarMap(Func):
pass
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 81e0ac3..5d8a4ca 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -188,6 +188,7 @@ class Generator:
exp.CollateProperty: exp.Properties.Location.POST_SCHEMA,
exp.CopyGrantsProperty: exp.Properties.Location.POST_SCHEMA,
exp.Cluster: exp.Properties.Location.POST_SCHEMA,
+ exp.ClusteredByProperty: exp.Properties.Location.POST_SCHEMA,
exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME,
exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
exp.DictRange: exp.Properties.Location.POST_SCHEMA,
@@ -1408,7 +1409,8 @@ class Generator:
expressions = (
f" {self.expressions(expression, flat=True)}" if expression.expressions else ""
)
- return f"SET{expressions}"
+ tag = " TAG" if expression.args.get("tag") else ""
+ return f"{'UNSET' if expression.args.get('unset') else 'SET'}{tag}{expressions}"
def pragma_sql(self, expression: exp.Pragma) -> str:
return f"PRAGMA {self.sql(expression, 'this')}"
@@ -1749,6 +1751,9 @@ class Generator:
return f"{self.sql(expression, 'this')}[{expressions_sql}]"
+ def safebracket_sql(self, expression: exp.SafeBracket) -> str:
+ return self.bracket_sql(expression)
+
def all_sql(self, expression: exp.All) -> str:
return f"ALL {self.wrap(expression)}"
@@ -2000,8 +2005,10 @@ class Generator:
def bitwisexor_sql(self, expression: exp.BitwiseXor) -> str:
return self.binary(expression, "^")
- def cast_sql(self, expression: exp.Cast) -> str:
- return f"CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
+ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
+ format_sql = self.sql(expression, "format")
+ format_sql = f" FORMAT {format_sql}" if format_sql else ""
+ return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')}{format_sql})"
def currentdate_sql(self, expression: exp.CurrentDate) -> str:
zone = self.sql(expression, "this")
@@ -2227,7 +2234,7 @@ class Generator:
return self.binary(expression, "-")
def trycast_sql(self, expression: exp.TryCast) -> str:
- return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
+ return self.cast_sql(expression, safe_prefix="TRY_")
def use_sql(self, expression: exp.Use) -> str:
kind = self.sql(expression, "kind")
@@ -2409,6 +2416,13 @@ class Generator:
def oncluster_sql(self, expression: exp.OnCluster) -> str:
return ""
+ def clusteredbyproperty_sql(self, expression: exp.ClusteredByProperty) -> str:
+ expressions = self.expressions(expression, key="expressions", flat=True)
+ sorted_by = self.expressions(expression, key="sorted_by", flat=True)
+ sorted_by = f" SORTED BY ({sorted_by})" if sorted_by else ""
+ buckets = self.sql(expression, "buckets")
+ return f"CLUSTERED BY ({expressions}){sorted_by} INTO {buckets} BUCKETS"
+
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py
index 5fdbde8..6e15c6a 100644
--- a/sqlglot/optimizer/qualify.py
+++ b/sqlglot/optimizer/qualify.py
@@ -60,8 +60,8 @@ def qualify(
The qualified expression.
"""
schema = ensure_schema(schema, dialect=dialect)
- expression = normalize_identifiers(expression, dialect=dialect)
expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema)
+ expression = normalize_identifiers(expression, dialect=dialect)
if isolate_tables:
expression = isolate_table_selects(expression, schema=schema)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index ac8eb0f..ef8aeb1 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -56,13 +56,13 @@ def qualify_columns(
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver, using_column_tables)
_qualify_outputs(scope)
- _expand_group_by(scope, resolver)
- _expand_order_by(scope)
+ _expand_group_by(scope)
+ _expand_order_by(scope, resolver)
return expression
-def validate_qualify_columns(expression):
+def validate_qualify_columns(expression: E) -> E:
"""Raise an `OptimizeError` if any columns aren't qualified"""
unqualified_columns = []
for scope in traverse_scope(expression):
@@ -79,7 +79,7 @@ def validate_qualify_columns(expression):
return expression
-def _pop_table_column_aliases(derived_tables):
+def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
"""
Remove table column aliases.
@@ -91,13 +91,13 @@ def _pop_table_column_aliases(derived_tables):
table_alias.args.pop("columns", None)
-def _expand_using(scope, resolver):
+def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
joins = list(scope.find_all(exp.Join))
names = {join.alias_or_name for join in joins}
ordered = [key for key in scope.selected_sources if key not in names]
# Mapping of automatically joined column names to an ordered set of source names (dict).
- column_tables = {}
+ column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
for join in joins:
using = join.args.get("using")
@@ -172,20 +172,25 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
alias_to_expression: t.Dict[str, exp.Expression] = {}
- def replace_columns(
- node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False
- ):
+ def replace_columns(node: t.Optional[exp.Expression], resolve_table: bool = False) -> None:
if not node:
return
for column, *_ in walk_in_scope(node):
if not isinstance(column, exp.Column):
continue
- table = resolver.get_table(column.name) if resolve_agg and not column.table else None
- if table and column.find_ancestor(exp.AggFunc):
+ table = resolver.get_table(column.name) if resolve_table and not column.table else None
+ alias_expr = alias_to_expression.get(column.name)
+ double_agg = (
+ (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
+ if alias_expr
+ else False
+ )
+
+ if table and (not alias_expr or double_agg):
column.set("table", table)
- elif expand and not column.table and column.name in alias_to_expression:
- column.replace(alias_to_expression[column.name].copy())
+ elif not column.table and alias_expr and not double_agg:
+ column.replace(alias_expr.copy())
for projection in scope.selects:
replace_columns(projection)
@@ -195,22 +200,41 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
replace_columns(expression.args.get("where"))
replace_columns(expression.args.get("group"))
- replace_columns(expression.args.get("having"), resolve_agg=True)
- replace_columns(expression.args.get("qualify"), resolve_agg=True)
- replace_columns(expression.args.get("order"), expand=False, resolve_agg=True)
+ replace_columns(expression.args.get("having"), resolve_table=True)
+ replace_columns(expression.args.get("qualify"), resolve_table=True)
scope.clear_cache()
-def _expand_group_by(scope, resolver):
- group = scope.expression.args.get("group")
+def _expand_group_by(scope: Scope):
+ expression = scope.expression
+ group = expression.args.get("group")
if not group:
return
group.set("expressions", _expand_positional_references(scope, group.expressions))
- scope.expression.set("group", group)
+ expression.set("group", group)
+
+ # group by expressions cannot be simplified, for example
+ # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
+ # the projection must exactly match the group by key
+ groups = set(group.expressions)
+ group.meta["final"] = True
+
+ for e in expression.selects:
+ for node, *_ in e.walk():
+ if node in groups:
+ e.meta["final"] = True
+ break
+ having = expression.args.get("having")
+ if having:
+ for node, *_ in having.walk():
+ if node in groups:
+ having.meta["final"] = True
+ break
-def _expand_order_by(scope):
+
+def _expand_order_by(scope: Scope, resolver: Resolver):
order = scope.expression.args.get("order")
if not order:
return
@@ -220,10 +244,21 @@ def _expand_order_by(scope):
ordereds,
_expand_positional_references(scope, (o.this for o in ordereds)),
):
+ for agg in ordered.find_all(exp.AggFunc):
+ for col in agg.find_all(exp.Column):
+ if not col.table:
+ col.set("table", resolver.get_table(col.name))
+
ordered.set("this", new_expression)
+ if scope.expression.args.get("group"):
+ selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects}
+
+ for ordered in ordereds:
+ ordered.set("this", selects.get(ordered.this, ordered.this))
-def _expand_positional_references(scope, expressions):
+
+def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
new_nodes = []
for node in expressions:
if node.is_int:
@@ -241,7 +276,7 @@ def _expand_positional_references(scope, expressions):
return new_nodes
-def _qualify_columns(scope, resolver):
+def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
"""Disambiguate columns, ensuring each column specifies a source"""
for column in scope.columns:
column_table = column.table
@@ -290,21 +325,23 @@ def _qualify_columns(scope, resolver):
column.set("table", column_table)
-def _expand_stars(scope, resolver, using_column_tables):
+def _expand_stars(
+ scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any]
+) -> None:
"""Expand stars to lists of column selections"""
new_selections = []
- except_columns = {}
- replace_columns = {}
+ except_columns: t.Dict[int, t.Set[str]] = {}
+ replace_columns: t.Dict[int, t.Dict[str, str]] = {}
coalesced_columns = set()
# TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
pivot_columns = None
pivot_output_columns = None
- pivot = seq_get(scope.pivots, 0)
+ pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
has_pivoted_source = pivot and not pivot.args.get("unpivot")
- if has_pivoted_source:
+ if pivot and has_pivoted_source:
pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))
pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
@@ -330,8 +367,17 @@ def _expand_stars(scope, resolver, using_column_tables):
columns = resolver.get_source_columns(table, only_visible=True)
+ # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement
+ # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
+ if resolver.schema.dialect == "bigquery":
+ columns = [
+ name
+ for name in columns
+ if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE")
+ ]
+
if columns and "*" not in columns:
- if has_pivoted_source:
+ if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
implicit_columns = [col for col in columns if col not in pivot_columns]
new_selections.extend(
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
@@ -368,7 +414,9 @@ def _expand_stars(scope, resolver, using_column_tables):
scope.expression.set("expressions", new_selections)
-def _add_except_columns(expression, tables, except_columns):
+def _add_except_columns(
+ expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
+) -> None:
except_ = expression.args.get("except")
if not except_:
@@ -380,7 +428,9 @@ def _add_except_columns(expression, tables, except_columns):
except_columns[id(table)] = columns
-def _add_replace_columns(expression, tables, replace_columns):
+def _add_replace_columns(
+ expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
+) -> None:
replace = expression.args.get("replace")
if not replace:
@@ -392,7 +442,7 @@ def _add_replace_columns(expression, tables, replace_columns):
replace_columns[id(table)] = columns
-def _qualify_outputs(scope):
+def _qualify_outputs(scope: Scope):
"""Ensure all output columns are aliased"""
new_selections = []
@@ -429,7 +479,7 @@ class Resolver:
This is a class so we can lazily load some things and easily share them across functions.
"""
- def __init__(self, scope, schema, infer_schema: bool = True):
+ def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
self.scope = scope
self.schema = schema
self._source_columns = None
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 5365aef..34005d9 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -28,6 +28,8 @@ def simplify(expression):
generate = cached_generator()
def _simplify(expression, root=True):
+ if expression.meta.get("final"):
+ return expression
node = expression
node = rewrite_between(node)
node = uniq_sort(node, generate, root)
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index e16a88e..e5bd4ae 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -585,6 +585,7 @@ class Parser(metaclass=_Parser):
"CHARACTER SET": lambda self: self._parse_character_set(),
"CHECKSUM": lambda self: self._parse_checksum(),
"CLUSTER BY": lambda self: self._parse_cluster(),
+ "CLUSTERED": lambda self: self._parse_clustered_by(),
"COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty),
"COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
"COPY": lambda self: self._parse_copy_property(),
@@ -794,8 +795,6 @@ class Parser(metaclass=_Parser):
# A NULL arg in CONCAT yields NULL by default
CONCAT_NULL_OUTPUTS_STRING = False
- CONVERT_TYPE_FIRST = False
-
PREFIXED_PIVOT_COLUMNS = False
IDENTIFY_PIVOT_STRINGS = False
@@ -1426,9 +1425,34 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT))
- def _parse_cluster(self) -> t.Optional[exp.Cluster]:
+ def _parse_cluster(self) -> exp.Cluster:
return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered))
+ def _parse_clustered_by(self) -> exp.ClusteredByProperty:
+ self._match_text_seq("BY")
+
+ self._match_l_paren()
+ expressions = self._parse_csv(self._parse_column)
+ self._match_r_paren()
+
+ if self._match_text_seq("SORTED", "BY"):
+ self._match_l_paren()
+ sorted_by = self._parse_csv(self._parse_ordered)
+ self._match_r_paren()
+ else:
+ sorted_by = None
+
+ self._match(TokenType.INTO)
+ buckets = self._parse_number()
+ self._match_text_seq("BUCKETS")
+
+ return self.expression(
+ exp.ClusteredByProperty,
+ expressions=expressions,
+ sorted_by=sorted_by,
+ buckets=buckets,
+ )
+
def _parse_copy_property(self) -> t.Optional[exp.CopyGrantsProperty]:
if not self._match_text_seq("GRANTS"):
self._retreat(self._index - 1)
@@ -2863,7 +2887,11 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.INTERVAL):
return None
- this = self._parse_primary() or self._parse_term()
+ if self._match(TokenType.STRING, advance=False):
+ this = self._parse_primary()
+ else:
+ this = self._parse_term()
+
unit = self._parse_function() or self._parse_var()
# Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
@@ -3661,6 +3689,7 @@ class Parser(metaclass=_Parser):
else:
self.raise_error("Expected AS after CAST")
+ fmt = None
to = self._parse_types()
if not to:
@@ -3668,22 +3697,23 @@ class Parser(metaclass=_Parser):
elif to.this == exp.DataType.Type.CHAR:
if self._match(TokenType.CHARACTER_SET):
to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())
- elif to.this in exp.DataType.TEMPORAL_TYPES and self._match(TokenType.FORMAT):
- fmt = self._parse_string()
+ elif self._match(TokenType.FORMAT):
+ fmt = self._parse_at_time_zone(self._parse_string())
- return self.expression(
- exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime,
- this=this,
- format=exp.Literal.string(
- format_time(
- fmt.this if fmt else "",
- self.FORMAT_MAPPING or self.TIME_MAPPING,
- self.FORMAT_TRIE or self.TIME_TRIE,
- )
- ),
- )
+ if to.this in exp.DataType.TEMPORAL_TYPES:
+ return self.expression(
+ exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime,
+ this=this,
+ format=exp.Literal.string(
+ format_time(
+ fmt.this if fmt else "",
+ self.FORMAT_MAPPING or self.TIME_MAPPING,
+ self.FORMAT_TRIE or self.TIME_TRIE,
+ )
+ ),
+ )
- return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+ return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt)
def _parse_concat(self) -> t.Optional[exp.Expression]:
args = self._parse_csv(self._parse_conjunction)
@@ -3704,20 +3734,23 @@ class Parser(metaclass=_Parser):
)
def _parse_string_agg(self) -> exp.Expression:
- expression: t.Optional[exp.Expression]
-
if self._match(TokenType.DISTINCT):
- args = self._parse_csv(self._parse_conjunction)
- expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)])
+ args: t.List[t.Optional[exp.Expression]] = [
+ self.expression(exp.Distinct, expressions=[self._parse_conjunction()])
+ ]
+ if self._match(TokenType.COMMA):
+ args.extend(self._parse_csv(self._parse_conjunction))
else:
args = self._parse_csv(self._parse_conjunction)
- expression = seq_get(args, 0)
index = self._index
if not self._match(TokenType.R_PAREN):
# postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]])
- order = self._parse_order(this=expression)
- return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
+ return self.expression(
+ exp.GroupConcat,
+ this=seq_get(args, 0),
+ separator=self._parse_order(this=seq_get(args, 1)),
+ )
# Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]).
# This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that
@@ -3727,24 +3760,21 @@ class Parser(metaclass=_Parser):
return self.validate_expression(exp.GroupConcat.from_arg_list(args), args)
self._match_l_paren() # The corresponding match_r_paren will be called in parse_function (caller)
- order = self._parse_order(this=expression)
+ order = self._parse_order(this=seq_get(args, 0))
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
- to: t.Optional[exp.Expression]
this = self._parse_bitwise()
if self._match(TokenType.USING):
- to = self.expression(exp.CharacterSet, this=self._parse_var())
+ to: t.Optional[exp.Expression] = self.expression(
+ exp.CharacterSet, this=self._parse_var()
+ )
elif self._match(TokenType.COMMA):
- to = self._parse_bitwise()
+ to = self._parse_types()
else:
to = None
- # Swap the argument order if needed to produce the correct AST
- if self.CONVERT_TYPE_FIRST:
- this, to = to, this
-
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
def _parse_decode(self) -> t.Optional[exp.Decode | exp.Case]:
@@ -4394,8 +4424,8 @@ class Parser(metaclass=_Parser):
if self._next:
self._advance()
- parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None
+ parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None
if parser:
actions = ensure_list(parser(self))
@@ -4516,9 +4546,11 @@ class Parser(metaclass=_Parser):
parser = self._find_parser(self.SET_PARSERS, self.SET_TRIE)
return parser(self) if parser else self._parse_set_item_assignment(kind=None)
- def _parse_set(self) -> exp.Set | exp.Command:
+ def _parse_set(self, unset: bool = False, tag: bool = False) -> exp.Set | exp.Command:
index = self._index
- set_ = self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
+ set_ = self.expression(
+ exp.Set, expressions=self._parse_csv(self._parse_set_item), unset=unset, tag=tag
+ )
if self._curr:
self._retreat(index)
@@ -4683,12 +4715,8 @@ class Parser(metaclass=_Parser):
exp.replace_children(this, self._replace_columns_with_dots)
table = this.args.get("table")
this = (
- self.expression(exp.Dot, this=table, expression=this.this)
- if table
- else self.expression(exp.Var, this=this.name)
+ self.expression(exp.Dot, this=table, expression=this.this) if table else this.this
)
- elif isinstance(this, exp.Identifier):
- this = self.expression(exp.Var, this=this.name)
return this
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index 4ed7449..f246702 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -91,6 +91,7 @@ class Step:
A Step DAG corresponding to `expression`.
"""
ctes = ctes or {}
+ expression = expression.unnest()
with_ = expression.args.get("with")
# CTEs break the mold of scope and introduce themselves to all in the context.
@@ -120,22 +121,25 @@ class Step:
projections = [] # final selects in this chain of steps representing a select
operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
- aggregations = []
+ aggregations = set()
next_operand_name = name_sequence("_a_")
def extract_agg_operands(expression):
- for agg in expression.find_all(exp.AggFunc):
+ agg_funcs = tuple(expression.find_all(exp.AggFunc))
+ if agg_funcs:
+ aggregations.add(expression)
+ for agg in agg_funcs:
for operand in agg.unnest_operands():
if isinstance(operand, exp.Column):
continue
if operand not in operands:
operands[operand] = next_operand_name()
operand.replace(exp.column(operands[operand], quoted=True))
+ return bool(agg_funcs)
for e in expression.expressions:
if e.find(exp.AggFunc):
projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
- aggregations.append(e)
extract_agg_operands(e)
else:
projections.append(e)
@@ -155,22 +159,38 @@ class Step:
having = expression.args.get("having")
if having:
- extract_agg_operands(having)
- aggregate.condition = having.this
+ if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
+ aggregate.condition = exp.column("_h", step.name, quoted=True)
+ else:
+ aggregate.condition = having.this
aggregate.operands = tuple(
alias(operand, alias_) for operand, alias_ in operands.items()
)
- aggregate.aggregations = aggregations
+ aggregate.aggregations = list(aggregations)
+
# give aggregates names and replace projections with references to them
aggregate.group = {
f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
}
+
+ intermediate: t.Dict[str | exp.Expression, str] = {}
+ for k, v in aggregate.group.items():
+ intermediate[v] = k
+ if isinstance(v, exp.Column):
+ intermediate[v.alias_or_name] = k
+
for projection in projections:
- for i, e in aggregate.group.items():
- for child, *_ in projection.walk():
- if child == e:
- child.replace(exp.column(i, step.name))
+ for node, *_ in projection.walk():
+ name = intermediate.get(node)
+ if name:
+ node.replace(exp.column(name, step.name))
+ if aggregate.condition:
+ for node, *_ in aggregate.condition.walk():
+ name = intermediate.get(node) or intermediate.get(node.name)
+ if name:
+ node.replace(exp.column(name, step.name))
+
aggregate.add_dependency(step)
step = aggregate
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index ba72616..1f30f96 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -159,10 +159,11 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Select):
from sqlglot.optimizer.scope import build_scope
- taken_select_names = set(expression.named_selects)
scope = build_scope(expression)
if not scope:
return expression
+
+ taken_select_names = set(expression.named_selects)
taken_source_names = set(scope.selected_sources)
for select in expression.selects: