summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/qualify_columns.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/qualify_columns.py')
-rw-r--r--sqlglot/optimizer/qualify_columns.py84
1 files changed, 52 insertions, 32 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 9c34cef..952999d 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
+from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema
@@ -58,6 +59,7 @@ def qualify_columns(
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver, using_column_tables, pseudocolumns)
_qualify_outputs(scope)
+
_expand_group_by(scope)
_expand_order_by(scope, resolver)
@@ -85,7 +87,7 @@ def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) ->
"""
Remove table column aliases.
- (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
+ For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
"""
for derived_table in derived_tables:
table_alias = derived_table.args.get("alias")
@@ -111,11 +113,11 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
columns = {}
- for k in scope.selected_sources:
- if k in ordered:
- for column in resolver.get_source_columns(k):
- if column not in columns:
- columns[column] = k
+ for source_name in scope.selected_sources:
+ if source_name in ordered:
+ for column_name in resolver.get_source_columns(source_name):
+ if column_name not in columns:
+ columns[column_name] = source_name
source_table = ordered[-1]
ordered.append(join_table)
@@ -183,6 +185,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
for column, *_ in walk_in_scope(node):
if not isinstance(column, exp.Column):
continue
+
table = resolver.get_table(column.name) if resolve_table and not column.table else None
alias_expr, i = alias_to_expression.get(column.name, (None, 1))
double_agg = (
@@ -198,7 +201,10 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
if literal_index:
column.replace(exp.Literal.number(i))
else:
- column.replace(alias_expr.copy())
+ column = column.replace(exp.paren(alias_expr))
+ simplified = simplify_parens(column)
+ if simplified is not column:
+ column.replace(simplified)
for i, projection in enumerate(scope.expression.selects):
replace_columns(projection)
@@ -213,7 +219,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
scope.clear_cache()
-def _expand_group_by(scope: Scope):
+def _expand_group_by(scope: Scope) -> None:
expression = scope.expression
group = expression.args.get("group")
if not group:
@@ -223,7 +229,7 @@ def _expand_group_by(scope: Scope):
expression.set("group", group)
-def _expand_order_by(scope: Scope, resolver: Resolver):
+def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
order = scope.expression.args.get("order")
if not order:
return
@@ -442,7 +448,7 @@ def _add_replace_columns(
replace_columns[id(table)] = columns
-def _qualify_outputs(scope: Scope):
+def _qualify_outputs(scope: Scope) -> None:
"""Ensure all output columns are aliased"""
new_selections = []
@@ -482,9 +488,9 @@ class Resolver:
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
self.scope = scope
self.schema = schema
- self._source_columns = None
+ self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
- self._all_columns = None
+ self._all_columns: t.Optional[t.Set[str]] = None
self._infer_schema = infer_schema
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
@@ -528,7 +534,7 @@ class Resolver:
return exp.to_identifier(table_name)
@property
- def all_columns(self):
+ def all_columns(self) -> t.Set[str]:
"""All available columns of all sources in this scope"""
if self._all_columns is None:
self._all_columns = {
@@ -536,53 +542,67 @@ class Resolver:
}
return self._all_columns
- def get_source_columns(self, name, only_visible=False):
- """Resolve the source columns for a given source `name`"""
+ def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
+ """Resolve the source columns for a given source `name`."""
if name not in self.scope.sources:
raise OptimizeError(f"Unknown table: {name}")
source = self.scope.sources[name]
- # If referencing a table, return the columns from the schema
if isinstance(source, exp.Table):
- return self.schema.column_names(source, only_visible)
+ columns = self.schema.column_names(source, only_visible)
+ elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
+ columns = source.expression.alias_column_names
+ else:
+ columns = source.expression.named_selects
- if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
- return source.expression.alias_column_names
+ node, _ = self.scope.selected_sources.get(name) or (None, None)
+ if isinstance(node, Scope):
+ column_aliases = node.expression.alias_column_names
+ elif isinstance(node, exp.Expression):
+ column_aliases = node.alias_column_names
+ else:
+ column_aliases = []
- # Otherwise, if referencing another scope, return that scope's named selects
- return source.expression.named_selects
+ # If the source's columns are aliased, their aliases shadow the corresponding column names
+ return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]
- def _get_all_source_columns(self):
+ def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
if self._source_columns is None:
self._source_columns = {
- k: self.get_source_columns(k)
- for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
+ source_name: self.get_source_columns(source_name)
+ for source_name, source in itertools.chain(
+ self.scope.selected_sources.items(), self.scope.lateral_sources.items()
+ )
}
return self._source_columns
- def _get_unambiguous_columns(self, source_columns):
+ def _get_unambiguous_columns(
+ self, source_columns: t.Dict[str, t.List[str]]
+ ) -> t.Dict[str, str]:
"""
Find all the unambiguous columns in sources.
Args:
- source_columns (dict): Mapping of names to source columns
+ source_columns: Mapping of names to source columns.
+
Returns:
- dict: Mapping of column name to source name
+ Mapping of column name to source name.
"""
if not source_columns:
return {}
- source_columns = list(source_columns.items())
+ source_columns_pairs = list(source_columns.items())
- first_table, first_columns = source_columns[0]
+ first_table, first_columns = source_columns_pairs[0]
unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
all_columns = set(unambiguous_columns)
- for table, columns in source_columns[1:]:
+ for table, columns in source_columns_pairs[1:]:
unique = self._find_unique_columns(columns)
ambiguous = set(all_columns).intersection(unique)
all_columns.update(columns)
+
for column in ambiguous:
unambiguous_columns.pop(column, None)
for column in unique.difference(ambiguous):
@@ -591,7 +611,7 @@ class Resolver:
return unambiguous_columns
@staticmethod
- def _find_unique_columns(columns):
+ def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
"""
Find the unique columns in a list of columns.
@@ -601,7 +621,7 @@ class Resolver:
This is necessary because duplicate column names are ambiguous.
"""
- counts = {}
+ counts: t.Dict[str, int] = {}
for column in columns:
counts[column] = counts.get(column, 0) + 1
return {column for column, count in counts.items() if count == 1}