summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/qualify_columns.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/qualify_columns.py')
-rw-r--r--sqlglot/optimizer/qualify_columns.py88
1 files changed, 63 insertions, 25 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 66b3170..5e40cf3 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -30,11 +30,12 @@ def qualify_columns(expression, schema):
resolver = Resolver(scope, schema)
_pop_table_column_aliases(scope.ctes)
_pop_table_column_aliases(scope.derived_tables)
- _expand_using(scope, resolver)
+ using_column_tables = _expand_using(scope, resolver)
_qualify_columns(scope, resolver)
if not isinstance(scope.expression, exp.UDTF):
- _expand_stars(scope, resolver)
+ _expand_stars(scope, resolver, using_column_tables)
_qualify_outputs(scope)
+ _expand_alias_refs(scope, resolver)
_expand_group_by(scope, resolver)
_expand_order_by(scope)
@@ -69,11 +70,11 @@ def _pop_table_column_aliases(derived_tables):
def _expand_using(scope, resolver):
- joins = list(scope.expression.find_all(exp.Join))
+ joins = list(scope.find_all(exp.Join))
names = {join.this.alias for join in joins}
ordered = [key for key in scope.selected_sources if key not in names]
- # Mapping of automatically joined column names to source names
+ # Mapping of automatically joined column names to an ordered set of source names (dict).
column_tables = {}
for join in joins:
@@ -112,11 +113,12 @@ def _expand_using(scope, resolver):
)
)
- tables = column_tables.setdefault(identifier, [])
+ # Set all values in the dict to None, because we only care about the key ordering
+ tables = column_tables.setdefault(identifier, {})
if table not in tables:
- tables.append(table)
+ tables[table] = None
if join_table not in tables:
- tables.append(join_table)
+ tables[join_table] = None
join.args.pop("using")
join.set("on", exp.and_(*conditions))
@@ -134,11 +136,11 @@ def _expand_using(scope, resolver):
scope.replace(column, replacement)
+ return column_tables
-def _expand_group_by(scope, resolver):
- group = scope.expression.args.get("group")
- if not group:
- return
+
+def _expand_alias_refs(scope, resolver):
+ selects = {}
# Replace references to select aliases
def transform(node, *_):
@@ -150,9 +152,11 @@ def _expand_group_by(scope, resolver):
node.set("table", table)
return node
- selects = {s.alias_or_name: s for s in scope.selects}
-
+ if not selects:
+ for s in scope.selects:
+ selects[s.alias_or_name] = s
select = selects.get(node.name)
+
if select:
scope.clear_cache()
if isinstance(select, exp.Alias):
@@ -161,7 +165,21 @@ def _expand_group_by(scope, resolver):
return node
- group.transform(transform, copy=False)
+ for select in scope.expression.selects:
+ select.transform(transform, copy=False)
+
+ for modifier in ("where", "group"):
+ part = scope.expression.args.get(modifier)
+
+ if part:
+ part.transform(transform, copy=False)
+
+
+def _expand_group_by(scope, resolver):
+ group = scope.expression.args.get("group")
+ if not group:
+ return
+
group.set("expressions", _expand_positional_references(scope, group.expressions))
scope.expression.set("group", group)
@@ -231,18 +249,24 @@ def _qualify_columns(scope, resolver):
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
columns_missing_from_scope = []
+
# Determine whether each reference in the order by clause is to a column or an alias.
- for ordered in scope.find_all(exp.Ordered):
- for column in ordered.find_all(exp.Column):
- if (
- not column.table
- and column.parent is not ordered
- and column.name in resolver.all_columns
- ):
- columns_missing_from_scope.append(column)
+ order = scope.expression.args.get("order")
+
+ if order:
+ for ordered in order.expressions:
+ for column in ordered.find_all(exp.Column):
+ if (
+ not column.table
+ and column.parent is not ordered
+ and column.name in resolver.all_columns
+ ):
+ columns_missing_from_scope.append(column)
# Determine whether each reference in the having clause is to a column or an alias.
- for having in scope.find_all(exp.Having):
+ having = scope.expression.args.get("having")
+
+ if having:
for column in having.find_all(exp.Column):
if (
not column.table
@@ -258,12 +282,13 @@ def _qualify_columns(scope, resolver):
column.set("table", column_table)
-def _expand_stars(scope, resolver):
+def _expand_stars(scope, resolver, using_column_tables):
"""Expand stars to lists of column selections"""
new_selections = []
except_columns = {}
replace_columns = {}
+ coalesced_columns = set()
for expression in scope.selects:
if isinstance(expression, exp.Star):
@@ -286,7 +311,20 @@ def _expand_stars(scope, resolver):
if columns and "*" not in columns:
table_id = id(table)
for name in columns:
- if name not in except_columns.get(table_id, set()):
+ if name in using_column_tables and table in using_column_tables[name]:
+ if name in coalesced_columns:
+ continue
+
+ coalesced_columns.add(name)
+ tables = using_column_tables[name]
+ coalesce = [exp.column(name, table=table) for table in tables]
+
+ new_selections.append(
+ exp.alias_(
+ exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
+ )
+ )
+ elif name not in except_columns.get(table_id, set()):
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table)
new_selections.append(alias(column, alias_) if alias_ != name else column)