diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 50 |
1 files changed, 22 insertions, 28 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 1656727..5c27bc3 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -6,7 +6,7 @@ import typing as t from sqlglot import alias, exp from sqlglot.dialects.dialect import Dialect, DialectType from sqlglot.errors import OptimizeError -from sqlglot.helper import seq_get +from sqlglot.helper import seq_get, SingleValuedMapping from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope from sqlglot.optimizer.simplify import simplify_parens from sqlglot.schema import Schema, ensure_schema @@ -586,8 +586,8 @@ class Resolver: def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): self.scope = scope self.schema = schema - self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None - self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None + self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None + self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None self._all_columns: t.Optional[t.Set[str]] = None self._infer_schema = infer_schema @@ -640,7 +640,7 @@ class Resolver: } return self._all_columns - def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]: + def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]: """Resolve the source columns for a given source `name`.""" if name not in self.scope.sources: raise OptimizeError(f"Unknown table: {name}") @@ -662,10 +662,15 @@ class Resolver: else: column_aliases = [] - # If the source's columns are aliased, their aliases shadow the corresponding column names - return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)] + if column_aliases: + # If the source's columns are aliased, their aliases shadow the corresponding column names. + # This can be expensive if there are lots of columns, so only do this if column_aliases exist. + return [ + alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases) + ] + return columns - def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]: + def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]: if self._source_columns is None: self._source_columns = { source_name: self.get_source_columns(source_name) @@ -676,8 +681,8 @@ class Resolver: return self._source_columns def _get_unambiguous_columns( - self, source_columns: t.Dict[str, t.List[str]] - ) -> t.Dict[str, str]: + self, source_columns: t.Dict[str, t.Sequence[str]] + ) -> t.Mapping[str, str]: """ Find all the unambiguous columns in sources. @@ -693,12 +698,17 @@ class Resolver: source_columns_pairs = list(source_columns.items()) first_table, first_columns = source_columns_pairs[0] - unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)} + + if len(source_columns_pairs) == 1: + # Performance optimization - avoid copying first_columns if there is only one table. + return SingleValuedMapping(first_columns, first_table) + + unambiguous_columns = {col: first_table for col in first_columns} all_columns = set(unambiguous_columns) for table, columns in source_columns_pairs[1:]: - unique = self._find_unique_columns(columns) - ambiguous = set(all_columns).intersection(unique) + unique = set(columns) + ambiguous = all_columns.intersection(unique) all_columns.update(columns) for column in ambiguous: @@ -707,19 +717,3 @@ class Resolver: unambiguous_columns[column] = table return unambiguous_columns - - @staticmethod - def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]: - """ - Find the unique columns in a list of columns. - - Example: - >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"])) - ['a', 'c'] - - This is necessary because duplicate column names are ambiguous. - """ - counts: t.Dict[str, int] = {} - for column in columns: - counts[column] = counts.get(column, 0) + 1 - return {column for column, count in counts.items() if count == 1} |