diff options
Diffstat (limited to 'sqlglot/optimizer/qualify_columns.py')
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 170 |
1 files changed, 124 insertions, 46 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 742cdf5..a6397ae 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -17,6 +17,7 @@ def qualify_columns( expression: exp.Expression, schema: t.Dict | Schema, expand_alias_refs: bool = True, + expand_stars: bool = True, infer_schema: t.Optional[bool] = None, ) -> exp.Expression: """ @@ -33,10 +34,16 @@ def qualify_columns( expression: Expression to qualify. schema: Database schema. expand_alias_refs: Whether or not to expand references to aliases. + expand_stars: Whether or not to expand star queries. This is a necessary step + for most of the optimizer's rules to work; do not set to False unless you + know what you're doing! infer_schema: Whether or not to infer the schema if missing. Returns: The qualified expression. + + Notes: + - Currently only handles a single PIVOT or UNPIVOT operator """ schema = ensure_schema(schema) infer_schema = schema.empty if infer_schema is None else infer_schema @@ -57,7 +64,8 @@ def qualify_columns( _expand_alias_refs(scope, resolver) if not isinstance(scope.expression, exp.UDTF): - _expand_stars(scope, resolver, using_column_tables, pseudocolumns) + if expand_stars: + _expand_stars(scope, resolver, using_column_tables, pseudocolumns) qualify_outputs(scope) _expand_group_by(scope) @@ -68,21 +76,41 @@ def qualify_columns( def validate_qualify_columns(expression: E) -> E: """Raise an `OptimizeError` if any columns aren't qualified""" - unqualified_columns = [] + all_unqualified_columns = [] for scope in traverse_scope(expression): if isinstance(scope.expression, exp.Select): - unqualified_columns.extend(scope.unqualified_columns) + unqualified_columns = scope.unqualified_columns + if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: column = scope.external_columns[0] - raise OptimizeError( - f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}""" - ) + for_table = f" for table: '{column.table}'" if column.table else "" + raise OptimizeError(f"Column '{column}' could not be resolved{for_table}") + + if unqualified_columns and scope.pivots and scope.pivots[0].unpivot: + # New columns produced by the UNPIVOT can't be qualified, but there may be columns + # under the UNPIVOT's IN clause that can and should be qualified. We recompute + # this list here to ensure those in the former category will be excluded. + unpivot_columns = set(_unpivot_columns(scope.pivots[0])) + unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns] + + all_unqualified_columns.extend(unqualified_columns) + + if all_unqualified_columns: + raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}") - if unqualified_columns: - raise OptimizeError(f"Ambiguous columns: {unqualified_columns}") return expression +def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]: + name_column = [] + field = unpivot.args.get("field") + if isinstance(field, exp.In) and isinstance(field.this, exp.Column): + name_column.append(field.this) + + value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column)) + return itertools.chain(name_column, value_columns) + + def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None: """ Remove table column aliases. @@ -216,6 +244,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: replace_columns(expression.args.get("group"), literal_index=True) replace_columns(expression.args.get("having"), resolve_table=True) replace_columns(expression.args.get("qualify"), resolve_table=True) + scope.clear_cache() @@ -353,18 +382,25 @@ def _expand_stars( replace_columns: t.Dict[int, t.Dict[str, str]] = {} coalesced_columns = set() - # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future - pivot_columns = None pivot_output_columns = None - pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0)) + pivot_exclude_columns = None - has_pivoted_source = pivot and not pivot.args.get("unpivot") - if pivot and has_pivoted_source: - pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column)) + pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0)) + if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names: + if pivot.unpivot: + pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)] + + field = pivot.args.get("field") + if isinstance(field, exp.In): + pivot_exclude_columns = { + c.output_name for e in field.expressions for c in e.find_all(exp.Column) + } + else: + pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column)) - pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])] - if not pivot_output_columns: - pivot_output_columns = [col.alias_or_name for col in pivot.expressions] + pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])] + if not pivot_output_columns: + pivot_output_columns = [c.alias_or_name for c in pivot.expressions] for expression in scope.expression.selects: if isinstance(expression, exp.Star): @@ -384,47 +420,54 @@ def _expand_stars( raise OptimizeError(f"Unknown table: {table}") columns = resolver.get_source_columns(table, only_visible=True) + columns = columns or scope.outer_column_list if pseudocolumns: columns = [name for name in columns if name.upper() not in pseudocolumns] - if columns and "*" not in columns: - table_id = id(table) - columns_to_exclude = except_columns.get(table_id) or set() + if not columns or "*" in columns: + return + + table_id = id(table) + columns_to_exclude = except_columns.get(table_id) or set() - if pivot and has_pivoted_source and pivot_columns and pivot_output_columns: - implicit_columns = [col for col in columns if col not in pivot_columns] + if pivot: + if pivot_output_columns and pivot_exclude_columns: + pivot_columns = [c for c in columns if c not in pivot_exclude_columns] + pivot_columns.extend(pivot_output_columns) + else: + pivot_columns = pivot.alias_column_names + + if pivot_columns: new_selections.extend( exp.alias_(exp.column(name, table=pivot.alias), name, copy=False) - for name in implicit_columns + pivot_output_columns + for name in pivot_columns if name not in columns_to_exclude ) continue - for name in columns: - if name in using_column_tables and table in using_column_tables[name]: - if name in coalesced_columns: - continue - - coalesced_columns.add(name) - tables = using_column_tables[name] - coalesce = [exp.column(name, table=table) for table in tables] - - new_selections.append( - alias( - exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), - alias=name, - copy=False, - ) - ) - elif name not in columns_to_exclude: - alias_ = replace_columns.get(table_id, {}).get(name, name) - column = exp.column(name, table=table) - new_selections.append( - alias(column, alias_, copy=False) if alias_ != name else column + for name in columns: + if name in using_column_tables and table in using_column_tables[name]: + if name in coalesced_columns: + continue + + coalesced_columns.add(name) + tables = using_column_tables[name] + coalesce = [exp.column(name, table=table) for table in tables] + + new_selections.append( + alias( + exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), + alias=name, + copy=False, ) - else: - return + ) + elif name not in columns_to_exclude: + alias_ = replace_columns.get(table_id, {}).get(name, name) + column = exp.column(name, table=table) + new_selections.append( + alias(column, alias_, copy=False) if alias_ != name else column + ) # Ensures we don't overwrite the initial selections with an empty list if new_selections: @@ -472,6 +515,9 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: for i, (selection, aliased_column) in enumerate( itertools.zip_longest(scope.expression.selects, scope.outer_column_list) ): + if selection is None: + break + if isinstance(selection, exp.Subquery): if not selection.output_name: selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) @@ -495,6 +541,38 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool ) +def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: + """ + Pushes down the CTE alias columns into the projection, + + This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") + >>> pushdown_cte_alias_columns(expression).sql() + 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y' + + Args: + expression: Expression to pushdown. + + Returns: + The expression with the CTE aliases pushed down into the projection. + """ + for cte in expression.find_all(exp.CTE): + if cte.alias_column_names: + new_expressions = [] + for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): + if isinstance(projection, exp.Alias): + projection.set("alias", _alias) + else: + projection = alias(projection, alias=_alias) + new_expressions.append(projection) + cte.this.set("expressions", new_expressions) + + return expression + + class Resolver: """ Helper for resolving columns. |