summary refs log tree commit diff stats
path: root/sqlglot/optimizer/qualify_columns.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r-- sqlglot/optimizer/qualify_columns.py 50
1 file changed, 22 insertions, 28 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 1656727..5c27bc3 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -6,7 +6,7 @@ import typing as t
from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
-from sqlglot.helper import seq_get
+from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema
@@ -586,8 +586,8 @@ class Resolver:
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
self.scope = scope
self.schema = schema
- self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
- self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
+ self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
+ self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
self._all_columns: t.Optional[t.Set[str]] = None
self._infer_schema = infer_schema
@@ -640,7 +640,7 @@ class Resolver:
}
return self._all_columns
- def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
+ def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
"""Resolve the source columns for a given source `name`."""
if name not in self.scope.sources:
raise OptimizeError(f"Unknown table: {name}")
@@ -662,10 +662,15 @@ class Resolver:
else:
column_aliases = []
- # If the source's columns are aliased, their aliases shadow the corresponding column names
- return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]
+ if column_aliases:
+ # If the source's columns are aliased, their aliases shadow the corresponding column names.
+ # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
+ return [
+ alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
+ ]
+ return columns
- def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
+ def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
if self._source_columns is None:
self._source_columns = {
source_name: self.get_source_columns(source_name)
@@ -676,8 +681,8 @@ class Resolver:
return self._source_columns
def _get_unambiguous_columns(
- self, source_columns: t.Dict[str, t.List[str]]
- ) -> t.Dict[str, str]:
+ self, source_columns: t.Dict[str, t.Sequence[str]]
+ ) -> t.Mapping[str, str]:
"""
Find all the unambiguous columns in sources.
@@ -693,12 +698,17 @@ class Resolver:
source_columns_pairs = list(source_columns.items())
first_table, first_columns = source_columns_pairs[0]
- unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
+
+ if len(source_columns_pairs) == 1:
+ # Performance optimization - avoid copying first_columns if there is only one table.
+ return SingleValuedMapping(first_columns, first_table)
+
+ unambiguous_columns = {col: first_table for col in first_columns}
all_columns = set(unambiguous_columns)
for table, columns in source_columns_pairs[1:]:
- unique = self._find_unique_columns(columns)
- ambiguous = set(all_columns).intersection(unique)
+ unique = set(columns)
+ ambiguous = all_columns.intersection(unique)
all_columns.update(columns)
for column in ambiguous:
@@ -707,19 +717,3 @@ class Resolver:
unambiguous_columns[column] = table
return unambiguous_columns
-
- @staticmethod
- def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
- """
- Find the unique columns in a list of columns.
-
- Example:
- >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
- ['a', 'c']
-
- This is necessary because duplicate column names are ambiguous.
- """
- counts: t.Dict[str, int] = {}
- for column in columns:
- counts[column] = counts.get(column, 0) + 1
- return {column for column, count in counts.items() if count == 1}