diff options
Diffstat (limited to 'sqlglot/optimizer/qualify_columns.py')
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 114 |
1 files changed, 82 insertions, 32 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index ac8eb0f..ef8aeb1 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -56,13 +56,13 @@ def qualify_columns( if not isinstance(scope.expression, exp.UDTF): _expand_stars(scope, resolver, using_column_tables) _qualify_outputs(scope) - _expand_group_by(scope, resolver) - _expand_order_by(scope) + _expand_group_by(scope) + _expand_order_by(scope, resolver) return expression -def validate_qualify_columns(expression): +def validate_qualify_columns(expression: E) -> E: """Raise an `OptimizeError` if any columns aren't qualified""" unqualified_columns = [] for scope in traverse_scope(expression): @@ -79,7 +79,7 @@ def validate_qualify_columns(expression): return expression -def _pop_table_column_aliases(derived_tables): +def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None: """ Remove table column aliases. @@ -91,13 +91,13 @@ def _pop_table_column_aliases(derived_tables): table_alias.args.pop("columns", None) -def _expand_using(scope, resolver): +def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]: joins = list(scope.find_all(exp.Join)) names = {join.alias_or_name for join in joins} ordered = [key for key in scope.selected_sources if key not in names] # Mapping of automatically joined column names to an ordered set of source names (dict). - column_tables = {} + column_tables: t.Dict[str, t.Dict[str, t.Any]] = {} for join in joins: using = join.args.get("using") @@ -172,20 +172,25 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: alias_to_expression: t.Dict[str, exp.Expression] = {} - def replace_columns( - node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False - ): + def replace_columns(node: t.Optional[exp.Expression], resolve_table: bool = False) -> None: if not node: return for column, *_ in walk_in_scope(node): if not isinstance(column, exp.Column): continue - table = resolver.get_table(column.name) if resolve_agg and not column.table else None - if table and column.find_ancestor(exp.AggFunc): + table = resolver.get_table(column.name) if resolve_table and not column.table else None + alias_expr = alias_to_expression.get(column.name) + double_agg = ( + (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc)) + if alias_expr + else False + ) + + if table and (not alias_expr or double_agg): column.set("table", table) - elif expand and not column.table and column.name in alias_to_expression: - column.replace(alias_to_expression[column.name].copy()) + elif not column.table and alias_expr and not double_agg: + column.replace(alias_expr.copy()) for projection in scope.selects: replace_columns(projection) @@ -195,22 +200,41 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: replace_columns(expression.args.get("where")) replace_columns(expression.args.get("group")) - replace_columns(expression.args.get("having"), resolve_agg=True) - replace_columns(expression.args.get("qualify"), resolve_agg=True) - replace_columns(expression.args.get("order"), expand=False, resolve_agg=True) + replace_columns(expression.args.get("having"), resolve_table=True) + replace_columns(expression.args.get("qualify"), resolve_table=True) scope.clear_cache() -def _expand_group_by(scope, resolver): - group = scope.expression.args.get("group") +def _expand_group_by(scope: Scope): + expression = scope.expression + group = expression.args.get("group") if not group: return group.set("expressions", _expand_positional_references(scope, group.expressions)) - scope.expression.set("group", group) + expression.set("group", group) + + # group by expressions cannot be simplified, for example + # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 + # the projection must exactly match the group by key + groups = set(group.expressions) + group.meta["final"] = True + + for e in expression.selects: + for node, *_ in e.walk(): + if node in groups: + e.meta["final"] = True + break + having = expression.args.get("having") + if having: + for node, *_ in having.walk(): + if node in groups: + having.meta["final"] = True + break -def _expand_order_by(scope): + +def _expand_order_by(scope: Scope, resolver: Resolver): order = scope.expression.args.get("order") if not order: return @@ -220,10 +244,21 @@ def _expand_order_by(scope): ordereds, _expand_positional_references(scope, (o.this for o in ordereds)), ): + for agg in ordered.find_all(exp.AggFunc): + for col in agg.find_all(exp.Column): + if not col.table: + col.set("table", resolver.get_table(col.name)) + ordered.set("this", new_expression) + if scope.expression.args.get("group"): + selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects} + + for ordered in ordereds: + ordered.set("this", selects.get(ordered.this, ordered.this)) -def _expand_positional_references(scope, expressions): + +def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]: new_nodes = [] for node in expressions: if node.is_int: @@ -241,7 +276,7 @@ def _expand_positional_references(scope, expressions): return new_nodes -def _qualify_columns(scope, resolver): +def _qualify_columns(scope: Scope, resolver: Resolver) -> None: """Disambiguate columns, ensuring each column specifies a source""" for column in scope.columns: column_table = column.table @@ -290,21 +325,23 @@ def _qualify_columns(scope, resolver): column.set("table", column_table) -def _expand_stars(scope, resolver, using_column_tables): +def _expand_stars( + scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any] +) -> None: """Expand stars to lists of column selections""" new_selections = [] - except_columns = {} - replace_columns = {} + except_columns: t.Dict[int, t.Set[str]] = {} + replace_columns: t.Dict[int, t.Dict[str, str]] = {} coalesced_columns = set() # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future pivot_columns = None pivot_output_columns = None - pivot = seq_get(scope.pivots, 0) + pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0)) has_pivoted_source = pivot and not pivot.args.get("unpivot") - if has_pivoted_source: + if pivot and has_pivoted_source: pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column)) pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])] @@ -330,8 +367,17 @@ def _expand_stars(scope, resolver, using_column_tables): columns = resolver.get_source_columns(table, only_visible=True) + # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement + # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table + if resolver.schema.dialect == "bigquery": + columns = [ + name + for name in columns + if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE") + ] + if columns and "*" not in columns: - if has_pivoted_source: + if pivot and has_pivoted_source and pivot_columns and pivot_output_columns: implicit_columns = [col for col in columns if col not in pivot_columns] new_selections.extend( exp.alias_(exp.column(name, table=pivot.alias), name, copy=False) @@ -368,7 +414,9 @@ def _expand_stars(scope, resolver, using_column_tables): scope.expression.set("expressions", new_selections) -def _add_except_columns(expression, tables, except_columns): +def _add_except_columns( + expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]] +) -> None: except_ = expression.args.get("except") if not except_: @@ -380,7 +428,9 @@ def _add_except_columns(expression, tables, except_columns): except_columns[id(table)] = columns -def _add_replace_columns(expression, tables, replace_columns): +def _add_replace_columns( + expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]] +) -> None: replace = expression.args.get("replace") if not replace: @@ -392,7 +442,7 @@ def _add_replace_columns(expression, tables, replace_columns): replace_columns[id(table)] = columns -def _qualify_outputs(scope): +def _qualify_outputs(scope: Scope): """Ensure all output columns are aliased""" new_selections = [] @@ -429,7 +479,7 @@ class Resolver: This is a class so we can lazily load some things and easily share them across functions. """ - def __init__(self, scope, schema, infer_schema: bool = True): + def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): self.scope = scope self.schema = schema self._source_columns = None |