summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- sqlglot/optimizer/eliminate_subqueries.py | 3
-rw-r--r-- sqlglot/optimizer/merge_subqueries.py | 2
-rw-r--r-- sqlglot/optimizer/pushdown_projections.py | 3
-rw-r--r-- sqlglot/optimizer/qualify_columns.py | 15
-rw-r--r-- sqlglot/optimizer/scope.py | 32
5 files changed, 37 insertions, 18 deletions
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 5ae1fa0..728493d 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -144,8 +144,9 @@ def _eliminate_derived_table(scope, existing_ctes, taken):
name, cte = _new_cte(scope, existing_ctes, taken)
table = exp.alias_(exp.table_(name), alias=parent.alias or name)
- parent.replace(table)
+ table.set("joins", parent.args.get("joins"))
+ parent.replace(table)
return cte
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index 6ee057b..7322424 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -176,6 +176,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
return (
isinstance(outer_scope.expression, exp.Select)
+ and not outer_scope.expression.is_star
and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
@@ -242,6 +243,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
alias (str)
"""
new_subquery = inner_scope.expression.args["from"].this
+ new_subquery.set("joins", node_to_replace.args.get("joins"))
node_to_replace.replace(new_subquery)
for join_hint in outer_scope.join_hints:
tables = join_hint.find_all(exp.Table)
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index 97e8ff6..c81fd00 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -61,6 +61,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
if remove_unused_selections:
_remove_unused_selections(scope, parent_selections, schema)
+ if scope.expression.is_star:
+ continue
+
# Group columns by source name
selects = defaultdict(set)
for col in scope.columns:
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 7972b2b..2657188 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -29,12 +29,13 @@ def qualify_columns(
'SELECT tbl.col AS col FROM tbl'
Args:
- expression: expression to qualify
- schema: Database schema
- expand_alias_refs: whether or not to expand references to aliases
- infer_schema: whether or not to infer the schema if missing
+ expression: Expression to qualify.
+ schema: Database schema.
+ expand_alias_refs: Whether or not to expand references to aliases.
+ infer_schema: Whether or not to infer the schema if missing.
+
Returns:
- sqlglot.Expression: qualified expression
+ The qualified expression.
"""
schema = ensure_schema(schema)
infer_schema = schema.empty if infer_schema is None else infer_schema
@@ -410,7 +411,9 @@ def _expand_stars(
else:
return
- scope.expression.set("expressions", new_selections)
+ # Ensures we don't overwrite the initial selections with an empty list
+ if new_selections:
+ scope.expression.set("expressions", new_selections)
def _add_except_columns(
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index b2b4230..a7dab35 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -124,8 +124,8 @@ class Scope:
self._ctes.append(node)
elif (
isinstance(node, exp.Subquery)
- and isinstance(parent, (exp.From, exp.Join))
- and _is_subquery_scope(node)
+ and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
+ and _is_derived_table(node)
):
self._derived_tables.append(node)
elif isinstance(node, exp.Subqueryable):
@@ -610,13 +610,13 @@ def _traverse_ctes(scope):
scope.sources.update(sources)
-def _is_subquery_scope(expression: exp.Subquery) -> bool:
+def _is_derived_table(expression: exp.Subquery) -> bool:
"""
- We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a new scope.
- If an alias is present, it shadows all names under the Subquery, so that's an
- exception to this rule.
+ We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
+ as it doesn't introduce a new scope. If an alias is present, it shadows all names
+ under the Subquery, so that's one exception to this rule.
"""
- return bool(not isinstance(expression.unnest(), exp.Table) or expression.alias)
+ return bool(expression.alias or isinstance(expression.this, exp.Subqueryable))
def _traverse_tables(scope):
@@ -654,7 +654,10 @@ def _traverse_tables(scope):
else:
sources[source_name] = expression
- expressions.extend(join.this for join in expression.args.get("joins") or [])
+ # Make sure to not include the joins twice
+ if expression is not scope.expression:
+ expressions.extend(join.this for join in expression.args.get("joins") or [])
+
continue
if not isinstance(expression, exp.DerivedTable):
@@ -664,10 +667,11 @@ def _traverse_tables(scope):
lateral_sources = sources
scope_type = ScopeType.UDTF
scopes = scope.udtf_scopes
- elif _is_subquery_scope(expression):
+ elif _is_derived_table(expression):
lateral_sources = None
scope_type = ScopeType.DERIVED_TABLE
scopes = scope.derived_table_scopes
+ expressions.extend(join.this for join in expression.args.get("joins") or [])
else:
# Makes sure we check for possible sources in nested table constructs
expressions.append(expression.this)
@@ -735,10 +739,16 @@ def walk_in_scope(expression, bfs=True):
isinstance(node, exp.CTE)
or (
isinstance(node, exp.Subquery)
- and isinstance(parent, (exp.From, exp.Join))
- and _is_subquery_scope(node)
+ and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
+ and _is_derived_table(node)
)
or isinstance(node, exp.UDTF)
or isinstance(node, exp.Subqueryable)
):
prune = True
+
+ if isinstance(node, (exp.Subquery, exp.UDTF)):
+ # The following args are not actually in the inner scope, so we should visit them
+ for key in ("joins", "laterals", "pivots"):
+ for arg in node.args.get(key) or []:
+ yield from walk_in_scope(arg, bfs=bfs)