summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/qualify_columns.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/qualify_columns.py')
-rw-r--r--sqlglot/optimizer/qualify_columns.py170
1 files changed, 124 insertions, 46 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 742cdf5..a6397ae 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -17,6 +17,7 @@ def qualify_columns(
expression: exp.Expression,
schema: t.Dict | Schema,
expand_alias_refs: bool = True,
+ expand_stars: bool = True,
infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
"""
@@ -33,10 +34,16 @@ def qualify_columns(
expression: Expression to qualify.
schema: Database schema.
expand_alias_refs: Whether or not to expand references to aliases.
+ expand_stars: Whether or not to expand star queries. This is a necessary step
+ for most of the optimizer's rules to work; do not set to False unless you
+ know what you're doing!
infer_schema: Whether or not to infer the schema if missing.
Returns:
The qualified expression.
+
+ Notes:
+ - Currently only handles a single PIVOT or UNPIVOT operator
"""
schema = ensure_schema(schema)
infer_schema = schema.empty if infer_schema is None else infer_schema
@@ -57,7 +64,8 @@ def qualify_columns(
_expand_alias_refs(scope, resolver)
if not isinstance(scope.expression, exp.UDTF):
- _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
+ if expand_stars:
+ _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
qualify_outputs(scope)
_expand_group_by(scope)
@@ -68,21 +76,41 @@ def qualify_columns(
def validate_qualify_columns(expression: E) -> E:
"""Raise an `OptimizeError` if any columns aren't qualified"""
- unqualified_columns = []
+ all_unqualified_columns = []
for scope in traverse_scope(expression):
if isinstance(scope.expression, exp.Select):
- unqualified_columns.extend(scope.unqualified_columns)
+ unqualified_columns = scope.unqualified_columns
+
if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
column = scope.external_columns[0]
- raise OptimizeError(
- f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
- )
+ for_table = f" for table: '{column.table}'" if column.table else ""
+ raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
+
+ if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
+ # New columns produced by the UNPIVOT can't be qualified, but there may be columns
+ # under the UNPIVOT's IN clause that can and should be qualified. We recompute
+ # this list here to ensure those in the former category will be excluded.
+ unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
+ unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
+
+ all_unqualified_columns.extend(unqualified_columns)
+
+ if all_unqualified_columns:
+ raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
- if unqualified_columns:
- raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
return expression
+def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
+ name_column = []
+ field = unpivot.args.get("field")
+ if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
+ name_column.append(field.this)
+
+ value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
+ return itertools.chain(name_column, value_columns)
+
+
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
"""
Remove table column aliases.
@@ -216,6 +244,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
replace_columns(expression.args.get("group"), literal_index=True)
replace_columns(expression.args.get("having"), resolve_table=True)
replace_columns(expression.args.get("qualify"), resolve_table=True)
+
scope.clear_cache()
@@ -353,18 +382,25 @@ def _expand_stars(
replace_columns: t.Dict[int, t.Dict[str, str]] = {}
coalesced_columns = set()
- # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
- pivot_columns = None
pivot_output_columns = None
- pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
+ pivot_exclude_columns = None
- has_pivoted_source = pivot and not pivot.args.get("unpivot")
- if pivot and has_pivoted_source:
- pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))
+ pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
+ if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
+ if pivot.unpivot:
+ pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]
+
+ field = pivot.args.get("field")
+ if isinstance(field, exp.In):
+ pivot_exclude_columns = {
+ c.output_name for e in field.expressions for c in e.find_all(exp.Column)
+ }
+ else:
+ pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))
- pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
- if not pivot_output_columns:
- pivot_output_columns = [col.alias_or_name for col in pivot.expressions]
+ pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
+ if not pivot_output_columns:
+ pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
for expression in scope.expression.selects:
if isinstance(expression, exp.Star):
@@ -384,47 +420,54 @@ def _expand_stars(
raise OptimizeError(f"Unknown table: {table}")
columns = resolver.get_source_columns(table, only_visible=True)
+ columns = columns or scope.outer_column_list
if pseudocolumns:
columns = [name for name in columns if name.upper() not in pseudocolumns]
- if columns and "*" not in columns:
- table_id = id(table)
- columns_to_exclude = except_columns.get(table_id) or set()
+ if not columns or "*" in columns:
+ return
+
+ table_id = id(table)
+ columns_to_exclude = except_columns.get(table_id) or set()
- if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
- implicit_columns = [col for col in columns if col not in pivot_columns]
+ if pivot:
+ if pivot_output_columns and pivot_exclude_columns:
+ pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
+ pivot_columns.extend(pivot_output_columns)
+ else:
+ pivot_columns = pivot.alias_column_names
+
+ if pivot_columns:
new_selections.extend(
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
- for name in implicit_columns + pivot_output_columns
+ for name in pivot_columns
if name not in columns_to_exclude
)
continue
- for name in columns:
- if name in using_column_tables and table in using_column_tables[name]:
- if name in coalesced_columns:
- continue
-
- coalesced_columns.add(name)
- tables = using_column_tables[name]
- coalesce = [exp.column(name, table=table) for table in tables]
-
- new_selections.append(
- alias(
- exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
- alias=name,
- copy=False,
- )
- )
- elif name not in columns_to_exclude:
- alias_ = replace_columns.get(table_id, {}).get(name, name)
- column = exp.column(name, table=table)
- new_selections.append(
- alias(column, alias_, copy=False) if alias_ != name else column
+ for name in columns:
+ if name in using_column_tables and table in using_column_tables[name]:
+ if name in coalesced_columns:
+ continue
+
+ coalesced_columns.add(name)
+ tables = using_column_tables[name]
+ coalesce = [exp.column(name, table=table) for table in tables]
+
+ new_selections.append(
+ alias(
+ exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
+ alias=name,
+ copy=False,
)
- else:
- return
+ )
+ elif name not in columns_to_exclude:
+ alias_ = replace_columns.get(table_id, {}).get(name, name)
+ column = exp.column(name, table=table)
+ new_selections.append(
+ alias(column, alias_, copy=False) if alias_ != name else column
+ )
# Ensures we don't overwrite the initial selections with an empty list
if new_selections:
@@ -472,6 +515,9 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
):
+ if selection is None:
+ break
+
if isinstance(selection, exp.Subquery):
if not selection.output_name:
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
@@ -495,6 +541,38 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool
)
+def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
+ """
+ Pushes down the CTE alias columns into the projection,
+
+ This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
+ >>> pushdown_cte_alias_columns(expression).sql()
+ 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
+
+ Args:
+ expression: Expression to pushdown.
+
+ Returns:
+ The expression with the CTE aliases pushed down into the projection.
+ """
+ for cte in expression.find_all(exp.CTE):
+ if cte.alias_column_names:
+ new_expressions = []
+ for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
+ if isinstance(projection, exp.Alias):
+ projection.set("alias", _alias)
+ else:
+ projection = alias(projection, alias=_alias)
+ new_expressions.append(projection)
+ cte.this.set("expressions", new_expressions)
+
+ return expression
+
+
class Resolver:
"""
Helper for resolving columns.