diff options
Diffstat (limited to 'sqlglot/optimizer/qualify_columns.py')
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 84 |
1 files changed, 52 insertions, 32 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 9c34cef..952999d 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import Dialect, DialectType from sqlglot.errors import OptimizeError from sqlglot.helper import seq_get from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope +from sqlglot.optimizer.simplify import simplify_parens from sqlglot.schema import Schema, ensure_schema @@ -58,6 +59,7 @@ def qualify_columns( if not isinstance(scope.expression, exp.UDTF): _expand_stars(scope, resolver, using_column_tables, pseudocolumns) _qualify_outputs(scope) + _expand_group_by(scope) _expand_order_by(scope, resolver) @@ -85,7 +87,7 @@ def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> """ Remove table column aliases. - (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2) + For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2) """ for derived_table in derived_tables: table_alias = derived_table.args.get("alias") @@ -111,11 +113,11 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]: columns = {} - for k in scope.selected_sources: - if k in ordered: - for column in resolver.get_source_columns(k): - if column not in columns: - columns[column] = k + for source_name in scope.selected_sources: + if source_name in ordered: + for column_name in resolver.get_source_columns(source_name): + if column_name not in columns: + columns[column_name] = source_name source_table = ordered[-1] ordered.append(join_table) @@ -183,6 +185,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: for column, *_ in walk_in_scope(node): if not isinstance(column, exp.Column): continue + table = resolver.get_table(column.name) if resolve_table and not column.table else None alias_expr, i = alias_to_expression.get(column.name, (None, 1)) double_agg = ( @@ -198,7 +201,10 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: if literal_index: column.replace(exp.Literal.number(i)) else: - column.replace(alias_expr.copy()) + column = column.replace(exp.paren(alias_expr)) + simplified = simplify_parens(column) + if simplified is not column: + column.replace(simplified) for i, projection in enumerate(scope.expression.selects): replace_columns(projection) @@ -213,7 +219,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: scope.clear_cache() -def _expand_group_by(scope: Scope): +def _expand_group_by(scope: Scope) -> None: expression = scope.expression group = expression.args.get("group") if not group: @@ -223,7 +229,7 @@ def _expand_group_by(scope: Scope): expression.set("group", group) -def _expand_order_by(scope: Scope, resolver: Resolver): +def _expand_order_by(scope: Scope, resolver: Resolver) -> None: order = scope.expression.args.get("order") if not order: return @@ -442,7 +448,7 @@ def _add_replace_columns( replace_columns[id(table)] = columns -def _qualify_outputs(scope: Scope): +def _qualify_outputs(scope: Scope) -> None: """Ensure all output columns are aliased""" new_selections = [] @@ -482,9 +488,9 @@ class Resolver: def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): self.scope = scope self.schema = schema - self._source_columns = None + self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None - self._all_columns = None + self._all_columns: t.Optional[t.Set[str]] = None self._infer_schema = infer_schema def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: @@ -528,7 +534,7 @@ class Resolver: return exp.to_identifier(table_name) @property - def all_columns(self): + def all_columns(self) -> t.Set[str]: """All available columns of all sources in this scope""" if self._all_columns is None: self._all_columns = { @@ -536,53 +542,67 @@ class Resolver: } return self._all_columns - def get_source_columns(self, name, only_visible=False): - """Resolve the source columns for a given source `name`""" + def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]: + """Resolve the source columns for a given source `name`.""" if name not in self.scope.sources: raise OptimizeError(f"Unknown table: {name}") source = self.scope.sources[name] - # If referencing a table, return the columns from the schema if isinstance(source, exp.Table): - return self.schema.column_names(source, only_visible) + columns = self.schema.column_names(source, only_visible) + elif isinstance(source, Scope) and isinstance(source.expression, exp.Values): + columns = source.expression.alias_column_names + else: + columns = source.expression.named_selects - if isinstance(source, Scope) and isinstance(source.expression, exp.Values): - return source.expression.alias_column_names + node, _ = self.scope.selected_sources.get(name) or (None, None) + if isinstance(node, Scope): + column_aliases = node.expression.alias_column_names + elif isinstance(node, exp.Expression): + column_aliases = node.alias_column_names + else: + column_aliases = [] - # Otherwise, if referencing another scope, return that scope's named selects - return source.expression.named_selects + # If the source's columns are aliased, their aliases shadow the corresponding column names + return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)] - def _get_all_source_columns(self): + def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]: if self._source_columns is None: self._source_columns = { - k: self.get_source_columns(k) - for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources) + source_name: self.get_source_columns(source_name) + for source_name, source in itertools.chain( + self.scope.selected_sources.items(), self.scope.lateral_sources.items() + ) } return self._source_columns - def _get_unambiguous_columns(self, source_columns): + def _get_unambiguous_columns( + self, source_columns: t.Dict[str, t.List[str]] + ) -> t.Dict[str, str]: """ Find all the unambiguous columns in sources. Args: - source_columns (dict): Mapping of names to source columns + source_columns: Mapping of names to source columns. + Returns: - dict: Mapping of column name to source name + Mapping of column name to source name. """ if not source_columns: return {} - source_columns = list(source_columns.items()) + source_columns_pairs = list(source_columns.items()) - first_table, first_columns = source_columns[0] + first_table, first_columns = source_columns_pairs[0] unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)} all_columns = set(unambiguous_columns) - for table, columns in source_columns[1:]: + for table, columns in source_columns_pairs[1:]: unique = self._find_unique_columns(columns) ambiguous = set(all_columns).intersection(unique) all_columns.update(columns) + for column in ambiguous: unambiguous_columns.pop(column, None) for column in unique.difference(ambiguous): @@ -591,7 +611,7 @@ class Resolver: return unambiguous_columns @staticmethod - def _find_unique_columns(columns): + def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]: """ Find the unique columns in a list of columns. @@ -601,7 +621,7 @@ class Resolver: This is necessary because duplicate column names are ambiguous. """ - counts = {} + counts: t.Dict[str, int] = {} for column in columns: counts[column] = counts.get(column, 0) + 1 return {column for column, count in counts.items() if count == 1} |