Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r--  sqlglot/optimizer/annotate_types.py         | 41
-rw-r--r--  sqlglot/optimizer/normalize_identifiers.py  |  6
-rw-r--r--  sqlglot/optimizer/qualify_columns.py        | 21
-rw-r--r--  sqlglot/optimizer/qualify_tables.py         |  4
-rw-r--r--  sqlglot/optimizer/scope.py                  | 12
-rw-r--r--  sqlglot/optimizer/simplify.py               | 16
-rw-r--r--  sqlglot/optimizer/unnest_subqueries.py      | 14
7 files changed, 90 insertions(+), 24 deletions(-)
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index ce274bb..81b1ee6 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -191,6 +191,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DateToDi,
exp.Floor,
exp.Levenshtein,
+ exp.Sign,
exp.StrPosition,
exp.TsOrDiToDi,
},
@@ -262,6 +263,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Div: lambda self, e: self._annotate_div(e),
+ exp.Dot: lambda self, e: self._annotate_dot(e),
exp.Explode: lambda self, e: self._annotate_explode(e),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
@@ -273,15 +275,17 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
+ exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
+ exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.Timestamp: lambda self, e: self._annotate_with_type(
e,
exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
+ exp.Unnest: lambda self, e: self._annotate_unnest(e),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
- exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
}
NESTED_TYPES = {
@@ -380,8 +384,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
self._set_type(col, self.schema.get_column_type(source, col))
- elif source and col.table in selects and col.name in selects[col.table]:
- self._set_type(col, selects[col.table][col.name].type)
+ elif source:
+ if col.table in selects and col.name in selects[col.table]:
+ self._set_type(col, selects[col.table][col.name].type)
+ elif isinstance(source.expression, exp.Unnest):
+ self._set_type(col, source.expression.type)
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
@@ -514,7 +521,14 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
last_datatype = None
for expr in expressions:
- last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
+ expr_type = expr.type
+
+ # Stop at the first nested data type found - we don't want to _maybe_coerce nested types
+ if expr_type.args.get("nested"):
+ last_datatype = expr_type
+ break
+
+ last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)
self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
@@ -594,7 +608,26 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return expression
+ def _annotate_dot(self, expression: exp.Dot) -> exp.Dot:
+ self._annotate_args(expression)
+ self._set_type(expression, None)
+ this_type = expression.this.type
+
+ if this_type and this_type.is_type(exp.DataType.Type.STRUCT):
+ for e in this_type.expressions:
+ if e.name == expression.expression.name:
+ self._set_type(expression, e.kind)
+ break
+
+ return expression
+
def _annotate_explode(self, expression: exp.Explode) -> exp.Explode:
self._annotate_args(expression)
self._set_type(expression, seq_get(expression.this.type.expressions, 0))
return expression
+
+ def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest:
+ self._annotate_args(expression)
+ child = seq_get(expression.expressions, 0)
+ self._set_type(expression, child and seq_get(child.type.expressions, 0))
+ return expression
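The new Struct, Dot and Unnest handlers are reachable through the public annotate_types entry point, typically after qualify has run. A rough sketch (the schema and query below are illustrative, not taken from this patch):

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.qualify import qualify

# Hypothetical schema: x.a is an ARRAY of STRUCTs, so the projections below
# exercise _annotate_unnest (element type of the UNNEST source) and
# _annotate_dot (field lookup inside the resulting STRUCT).
schema = {"x": {"a": "ARRAY<STRUCT<b INT>>"}}
expression = qualify(
    sqlglot.parse_one("SELECT a, a.b FROM x, UNNEST(x.a) AS a", read="bigquery"),
    schema=schema,
    dialect="bigquery",
)
for projection in annotate_types(expression, schema=schema).selects:
    print(projection.sql(dialect="bigquery"), projection.type)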
diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
index d22a998..f2a0990 100644
--- a/sqlglot/optimizer/normalize_identifiers.py
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -10,13 +10,11 @@ if t.TYPE_CHECKING:
@t.overload
-def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
- ...
+def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ...
@t.overload
-def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
- ...
+def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ...
def normalize_identifiers(expression, dialect=None):
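For context, the two overloads above cover the two supported inputs; a minimal sketch:

from sqlglot import parse_one
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

# str input -> a normalized exp.Identifier (second overload).
print(normalize_identifiers("Foo", dialect="snowflake").sql())  # FOO

# Expression input -> the same expression type back (first overload).
print(normalize_identifiers(parse_one("SELECT Bar"), dialect="duckdb").sql())  # SELECT bar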
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index ef589c9..233ffc9 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -120,6 +120,8 @@ def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) ->
For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
"""
for derived_table in derived_tables:
+ if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
+ continue
table_alias = derived_table.args.get("alias")
if table_alias:
table_alias.args.pop("columns", None)
@@ -214,7 +216,13 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
table = resolver.get_table(column.name) if resolve_table and not column.table else None
alias_expr, i = alias_to_expression.get(column.name, (None, 1))
double_agg = (
- (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
+ (
+ alias_expr.find(exp.AggFunc)
+ and (
+ column.find_ancestor(exp.AggFunc)
+ and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
+ )
+ )
if alias_expr
else False
)
@@ -404,7 +412,7 @@ def _expand_stars(
tables = list(scope.selected_sources)
_add_except_columns(expression, tables, except_columns)
_add_replace_columns(expression, tables, replace_columns)
- elif expression.is_star:
+ elif expression.is_star and not isinstance(expression, exp.Dot):
tables = [expression.table]
_add_except_columns(expression.this, tables, except_columns)
_add_replace_columns(expression.this, tables, replace_columns)
@@ -437,7 +445,7 @@ def _expand_stars(
if pivot_columns:
new_selections.extend(
- exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
+ alias(exp.column(name, table=pivot.alias), name, copy=False)
for name in pivot_columns
if name not in columns_to_exclude
)
@@ -466,7 +474,7 @@ def _expand_stars(
)
# Ensures we don't overwrite the initial selections with an empty list
- if new_selections:
+ if new_selections and isinstance(scope.expression, exp.Select):
scope.expression.set("expressions", new_selections)
@@ -528,7 +536,8 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
new_selections.append(selection)
- scope.expression.set("expressions", new_selections)
+ if isinstance(scope.expression, exp.Select):
+ scope.expression.set("expressions", new_selections)
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
@@ -615,7 +624,7 @@ class Resolver:
node, _ = self.scope.selected_sources.get(table_name)
- if isinstance(node, exp.Subqueryable):
+ if isinstance(node, exp.Query):
while node and node.alias != table_name:
node = node.parent
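The recursive-CTE guard is the user-visible part of this file's changes: column aliases on a recursive CTE are no longer popped, so the recursive member can still resolve them. A rough sketch with an illustrative query (default dialect):

from sqlglot import parse_one
from sqlglot.optimizer.qualify import qualify

# The (n) column-alias list on the recursive CTE is kept, so both the anchor
# member and the recursive member can resolve the column n against it.
sql = """
WITH RECURSIVE t(n) AS (
    SELECT 1 UNION ALL SELECT n + 1 FROM t WHERE n < 3
)
SELECT n FROM t
"""
print(qualify(parse_one(sql)).sql(pretty=True))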
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index d460e81..214ac0a 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -55,8 +55,8 @@ def qualify_tables(
if not table.args.get("catalog") and table.args.get("db"):
table.set("catalog", catalog)
- if not isinstance(expression, exp.Subqueryable):
- for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)):
+ if not isinstance(expression, exp.Query):
+ for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Query)):
if isinstance(node, exp.Table):
_qualify(node)
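Behavior for plain queries is unchanged; the isinstance checks simply move from the removed Subqueryable/Unionable classes to exp.Query. For reference, a minimal sketch of the rule itself:

from sqlglot import parse_one
from sqlglot.optimizer.qualify_tables import qualify_tables

# Unqualified table names pick up the provided catalog/db and get an alias.
print(qualify_tables(parse_one("SELECT 1 FROM tbl"), db="db", catalog="cat").sql())
# Expected roughly: SELECT 1 FROM cat.db.tbl AS tbl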
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 0eae979..443fa6c 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -138,7 +138,7 @@ class Scope:
and _is_derived_table(node)
):
self._derived_tables.append(node)
- elif isinstance(node, exp.Subqueryable):
+ elif isinstance(node, exp.UNWRAPPED_QUERIES):
self._subqueries.append(node)
self._collected = True
@@ -225,7 +225,7 @@ class Scope:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
- list[exp.Subqueryable]: subqueries
+ list[exp.Select | exp.Union]: subqueries
"""
self._ensure_collected()
return self._subqueries
@@ -486,8 +486,8 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
Returns:
list[Scope]: scope instances
"""
- if isinstance(expression, exp.Unionable) or (
- isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Unionable)
+ if isinstance(expression, exp.Query) or (
+ isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query)
):
return list(_traverse_scope(Scope(expression)))
@@ -615,7 +615,7 @@ def _is_derived_table(expression: exp.Subquery) -> bool:
as it doesn't introduce a new scope. If an alias is present, it shadows all names
under the Subquery, so that's one exception to this rule.
"""
- return bool(expression.alias or isinstance(expression.this, exp.Subqueryable))
+ return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES))
def _traverse_tables(scope):
@@ -786,7 +786,7 @@ def walk_in_scope(expression, bfs=True, prune=None):
and _is_derived_table(node)
)
or isinstance(node, exp.UDTF)
- or isinstance(node, exp.Subqueryable)
+ or isinstance(node, exp.UNWRAPPED_QUERIES)
):
crossed_scope_boundary = True
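The traversal entry points keep their shape; they now key off exp.Query and exp.UNWRAPPED_QUERIES. A small sketch of walking scopes, including the DDL case handled above (illustrative query):

from sqlglot import parse_one
from sqlglot.optimizer.scope import traverse_scope

# DDL wrapping a query is still traversed, and the IN (...) subquery gets its own scope.
expression = parse_one("CREATE TABLE y AS SELECT a FROM x WHERE a IN (SELECT a FROM z)")
for scope in traverse_scope(expression):
    print(type(scope.expression).__name__, sorted(scope.sources))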
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 9ffddb5..2e43d21 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -1185,7 +1185,7 @@ def gen(expression: t.Any) -> str:
GEN_MAP = {
exp.Add: lambda e: _binary(e, "+"),
exp.And: lambda e: _binary(e, "AND"),
- exp.Anonymous: lambda e: f"{e.this.upper()} {','.join(gen(e) for e in e.expressions)}",
+ exp.Anonymous: lambda e: _anonymous(e),
exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
@@ -1219,6 +1219,20 @@ GEN_MAP = {
}
+def _anonymous(e: exp.Anonymous) -> str:
+ this = e.this
+ if isinstance(this, str):
+ name = this.upper()
+ elif isinstance(this, exp.Identifier):
+ name = f'"{this.name}"' if this.quoted else this.name.upper()
+ else:
+ raise ValueError(
+ f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
+ )
+
+ return f"{name} {','.join(gen(e) for e in e.expressions)}"
+
+
def _binary(e: exp.Binary, op: str) -> str:
return f"{gen(e.left)} {op} {gen(e.right)}"
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index b4c7475..36d9da4 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -94,8 +94,20 @@ def unnest(select, parent_select, next_alias_name):
else:
_replace(predicate, join_key_not_null)
+ group = select.args.get("group")
+
+ if group:
+ if {value.this} != set(group.expressions):
+ select = (
+ exp.select(exp.column(value.alias, "_q"))
+ .from_(select.subquery("_q", copy=False), copy=False)
+ .group_by(exp.column(value.alias, "_q"), copy=False)
+ )
+ else:
+ select = select.group_by(value.this, copy=False)
+
parent_select.join(
- select.group_by(value.this, copy=False),
+ select,
on=column.eq(join_key),
join_type="LEFT",
join_alias=alias,
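The rewrite's basic shape is unchanged; the new branch only decides what happens when the correlated subquery already carries its own GROUP BY. A minimal sketch of the rule's entry point (illustrative correlated query):

import sqlglot
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

# The correlated scalar subquery becomes a LEFT JOIN grouped by the join key.
# With this patch, a subquery that already groups by something other than the
# join key is first wrapped in a "_q" derived table so its grouping is kept.
sql = "SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1"
print(unnest_subqueries(sqlglot.parse_one(sql)).sql())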