summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/qualify_columns.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/qualify_columns.py')
-rw-r--r--sqlglot/optimizer/qualify_columns.py221
1 files changed, 126 insertions, 95 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 6ac39f0..4a31171 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -1,14 +1,23 @@
+from __future__ import annotations
+
import itertools
import typing as t
from sqlglot import alias, exp
+from sqlglot._typing import E
+from sqlglot.dialects.dialect import DialectType
from sqlglot.errors import OptimizeError
-from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals
-from sqlglot.optimizer.scope import Scope, traverse_scope
-from sqlglot.schema import ensure_schema
+from sqlglot.helper import case_sensitive, seq_get
+from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
+from sqlglot.schema import Schema, ensure_schema
-def qualify_columns(expression, schema, expand_laterals=True):
+def qualify_columns(
+ expression: exp.Expression,
+ schema: dict | Schema,
+ expand_alias_refs: bool = True,
+ infer_schema: t.Optional[bool] = None,
+) -> exp.Expression:
"""
Rewrite sqlglot AST to have fully qualified columns.
@@ -20,32 +29,36 @@ def qualify_columns(expression, schema, expand_laterals=True):
'SELECT tbl.col AS col FROM tbl'
Args:
- expression (sqlglot.Expression): expression to qualify
- schema (dict|sqlglot.optimizer.Schema): Database schema
+ expression: expression to qualify
+ schema: Database schema
+ expand_alias_refs: whether or not to expand references to aliases
+ infer_schema: whether or not to infer the schema if missing
Returns:
sqlglot.Expression: qualified expression
"""
schema = ensure_schema(schema)
-
- if not schema.mapping and expand_laterals:
- expression = _expand_laterals(expression)
+ infer_schema = schema.empty if infer_schema is None else infer_schema
for scope in traverse_scope(expression):
- resolver = Resolver(scope, schema)
+ resolver = Resolver(scope, schema, infer_schema=infer_schema)
_pop_table_column_aliases(scope.ctes)
_pop_table_column_aliases(scope.derived_tables)
using_column_tables = _expand_using(scope, resolver)
+
+ if schema.empty and expand_alias_refs:
+ _expand_alias_refs(scope, resolver)
+
_qualify_columns(scope, resolver)
+
+ if not schema.empty and expand_alias_refs:
+ _expand_alias_refs(scope, resolver)
+
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver, using_column_tables)
_qualify_outputs(scope)
- _expand_alias_refs(scope, resolver)
_expand_group_by(scope, resolver)
_expand_order_by(scope)
- if schema.mapping and expand_laterals:
- expression = _expand_laterals(expression)
-
return expression
@@ -55,9 +68,11 @@ def validate_qualify_columns(expression):
for scope in traverse_scope(expression):
if isinstance(scope.expression, exp.Select):
unqualified_columns.extend(scope.unqualified_columns)
- if scope.external_columns and not scope.is_correlated_subquery:
+ if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
column = scope.external_columns[0]
- raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")
+ raise OptimizeError(
+ f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
+ )
if unqualified_columns:
raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
@@ -142,52 +157,48 @@ def _expand_using(scope, resolver):
# Ensure selects keep their output name
if isinstance(column.parent, exp.Select):
- replacement = exp.alias_(replacement, alias=column.name)
+ replacement = alias(replacement, alias=column.name, copy=False)
scope.replace(column, replacement)
return column_tables
-def _expand_alias_refs(scope, resolver):
- selects = {}
-
- # Replace references to select aliases
- def transform(node, source_first=True):
- if isinstance(node, exp.Column) and not node.table:
- table = resolver.get_table(node.name)
+def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
+ expression = scope.expression
- # Source columns get priority over select aliases
- if source_first and table:
- node.set("table", table)
- return node
+ if not isinstance(expression, exp.Select):
+ return
- if not selects:
- for s in scope.selects:
- selects[s.alias_or_name] = s
- select = selects.get(node.name)
+ alias_to_expression: t.Dict[str, exp.Expression] = {}
- if select:
- scope.clear_cache()
- if isinstance(select, exp.Alias):
- select = select.this
- return select.copy()
+ def replace_columns(
+ node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False
+ ):
+ if not node:
+ return
- node.set("table", table)
- elif isinstance(node, exp.Expression) and not isinstance(node, exp.Subqueryable):
- exp.replace_children(node, transform, source_first)
+ for column, *_ in walk_in_scope(node):
+ if not isinstance(column, exp.Column):
+ continue
+ table = resolver.get_table(column.name) if resolve_agg and not column.table else None
+ if table and column.find_ancestor(exp.AggFunc):
+ column.set("table", table)
+ elif expand and not column.table and column.name in alias_to_expression:
+ column.replace(alias_to_expression[column.name].copy())
- return node
+ for projection in scope.selects:
+ replace_columns(projection)
- for select in scope.expression.selects:
- transform(select)
+ if isinstance(projection, exp.Alias):
+ alias_to_expression[projection.alias] = projection.this
- for modifier, source_first in (
- ("where", True),
- ("group", True),
- ("having", False),
- ):
- transform(scope.expression.args.get(modifier), source_first=source_first)
+ replace_columns(expression.args.get("where"))
+ replace_columns(expression.args.get("group"))
+ replace_columns(expression.args.get("having"), resolve_agg=True)
+ replace_columns(expression.args.get("qualify"), resolve_agg=True)
+ replace_columns(expression.args.get("order"), expand=False, resolve_agg=True)
+ scope.clear_cache()
def _expand_group_by(scope, resolver):
@@ -242,6 +253,12 @@ def _qualify_columns(scope, resolver):
raise OptimizeError(f"Unknown column: {column_name}")
if not column_table:
+ if scope.pivots and not column.find_ancestor(exp.Pivot):
+ # If the column is under the Pivot expression, we need to qualify it
+ # using the name of the pivoted source instead of the pivot's alias
+ column.set("table", exp.to_identifier(scope.pivots[0].alias))
+ continue
+
column_table = resolver.get_table(column_name)
# column_table can be a '' because bigquery unnest has no table alias
@@ -265,38 +282,12 @@ def _qualify_columns(scope, resolver):
if column_table:
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
- columns_missing_from_scope = []
-
- # Determine whether each reference in the order by clause is to a column or an alias.
- order = scope.expression.args.get("order")
-
- if order:
- for ordered in order.expressions:
- for column in ordered.find_all(exp.Column):
- if (
- not column.table
- and column.parent is not ordered
- and column.name in resolver.all_columns
- ):
- columns_missing_from_scope.append(column)
-
- # Determine whether each reference in the having clause is to a column or an alias.
- having = scope.expression.args.get("having")
-
- if having:
- for column in having.find_all(exp.Column):
- if (
- not column.table
- and column.find_ancestor(exp.AggFunc)
- and column.name in resolver.all_columns
- ):
- columns_missing_from_scope.append(column)
-
- for column in columns_missing_from_scope:
- column_table = resolver.get_table(column.name)
-
- if column_table:
- column.set("table", column_table)
+ for pivot in scope.pivots:
+ for column in pivot.find_all(exp.Column):
+ if not column.table and column.name in resolver.all_columns:
+ column_table = resolver.get_table(column.name)
+ if column_table:
+ column.set("table", column_table)
def _expand_stars(scope, resolver, using_column_tables):
@@ -307,6 +298,19 @@ def _expand_stars(scope, resolver, using_column_tables):
replace_columns = {}
coalesced_columns = set()
+ # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
+ pivot_columns = None
+ pivot_output_columns = None
+ pivot = seq_get(scope.pivots, 0)
+
+ has_pivoted_source = pivot and not pivot.args.get("unpivot")
+ if has_pivoted_source:
+ pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))
+
+ pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
+ if not pivot_output_columns:
+ pivot_output_columns = [col.alias_or_name for col in pivot.expressions]
+
for expression in scope.selects:
if isinstance(expression, exp.Star):
tables = list(scope.selected_sources)
@@ -323,9 +327,18 @@ def _expand_stars(scope, resolver, using_column_tables):
for table in tables:
if table not in scope.sources:
raise OptimizeError(f"Unknown table: {table}")
+
columns = resolver.get_source_columns(table, only_visible=True)
if columns and "*" not in columns:
+ if has_pivoted_source:
+ implicit_columns = [col for col in columns if col not in pivot_columns]
+ new_selections.extend(
+ exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
+ for name in implicit_columns + pivot_output_columns
+ )
+ continue
+
table_id = id(table)
for name in columns:
if name in using_column_tables and table in using_column_tables[name]:
@@ -337,16 +350,21 @@ def _expand_stars(scope, resolver, using_column_tables):
coalesce = [exp.column(name, table=table) for table in tables]
new_selections.append(
- exp.alias_(
- exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
+ alias(
+ exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
+ alias=name,
+ copy=False,
)
)
elif name not in except_columns.get(table_id, set()):
alias_ = replace_columns.get(table_id, {}).get(name, name)
- column = exp.column(name, table)
- new_selections.append(alias(column, alias_) if alias_ != name else column)
+ column = exp.column(name, table=table)
+ new_selections.append(
+ alias(column, alias_, copy=False) if alias_ != name else column
+ )
else:
return
+
scope.expression.set("expressions", new_selections)
@@ -388,9 +406,6 @@ def _qualify_outputs(scope):
selection = alias(
selection,
alias=selection.output_name or f"_col_{i}",
- quoted=True
- if isinstance(selection, exp.Column) and selection.this.quoted
- else None,
)
if aliased_column:
selection.set("alias", exp.to_identifier(aliased_column))
@@ -400,6 +415,23 @@ def _qualify_outputs(scope):
scope.expression.set("expressions", new_selections)
+def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
+ """Makes sure all identifiers that need to be quoted are quoted."""
+
+ def _quote(expression: E) -> E:
+ if isinstance(expression, exp.Identifier):
+ name = expression.this
+ expression.set(
+ "quoted",
+ identify
+ or case_sensitive(name, dialect=dialect)
+ or not exp.SAFE_IDENTIFIER_RE.match(name),
+ )
+ return expression
+
+ return expression.transform(_quote, copy=False)
+
+
class Resolver:
"""
Helper for resolving columns.
@@ -407,12 +439,13 @@ class Resolver:
This is a class so we can lazily load some things and easily share them across functions.
"""
- def __init__(self, scope, schema):
+ def __init__(self, scope, schema, infer_schema: bool = True):
self.scope = scope
self.schema = schema
self._source_columns = None
- self._unambiguous_columns = None
+ self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
self._all_columns = None
+ self._infer_schema = infer_schema
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
"""
@@ -430,7 +463,7 @@ class Resolver:
table_name = self._unambiguous_columns.get(column_name)
- if not table_name:
+ if not table_name and self._infer_schema:
sources_without_schema = tuple(
source
for source, columns in self._get_all_source_columns().items()
@@ -450,11 +483,9 @@ class Resolver:
node_alias = node.args.get("alias")
if node_alias:
- return node_alias.this
+ return exp.to_identifier(node_alias.this)
- return exp.to_identifier(
- table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
- )
+ return exp.to_identifier(table_name)
@property
def all_columns(self):