summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/optimizer/qualify_columns.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/qualify_columns.py')
-rw-r--r-- sqlglot/optimizer/qualify_columns.py 54
1 files changed, 28 insertions, 26 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 8da4e43..54425a8 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -1,7 +1,8 @@
import itertools
+import typing as t
from sqlglot import alias, exp
-from sqlglot.errors import OptimizeError, SchemaError
+from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@@ -190,20 +191,15 @@ def _qualify_columns(scope, resolver):
column_table = column.table
column_name = column.name
- if (
- column_table
- and column_table in scope.sources
- and column_name not in resolver.get_source_columns(column_table)
- ):
- raise OptimizeError(f"Unknown column: {column_name}")
+ if column_table and column_table in scope.sources:
+ source_columns = resolver.get_source_columns(column_table)
+ if source_columns and column_name not in source_columns:
+ raise OptimizeError(f"Unknown column: {column_name}")
if not column_table:
column_table = resolver.get_table(column_name)
if not scope.is_subquery and not scope.is_udtf:
- if column_name not in resolver.all_columns:
- raise OptimizeError(f"Unknown column: {column_name}")
-
if column_table is None:
raise OptimizeError(f"Ambiguous column: {column_name}")
@@ -265,6 +261,10 @@ def _expand_stars(scope, resolver):
if table not in scope.sources:
raise OptimizeError(f"Unknown table: {table}")
columns = resolver.get_source_columns(table, only_visible=True)
+ if not columns:
+ raise OptimizeError(
+ f"Table has no schema/columns. Cannot expand star for table: {table}."
+ )
table_id = id(table)
for name in columns:
if name not in except_columns.get(table_id, set()):
@@ -306,16 +306,11 @@ def _qualify_outputs(scope):
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.selects, scope.outer_column_list)
):
- if isinstance(selection, exp.Column):
- # convoluted setter because a simple selection.replace(alias) would require a copy
- alias_ = alias(exp.column(""), alias=selection.name)
- alias_.set("this", selection)
- selection = alias_
- elif isinstance(selection, exp.Subquery):
- if not selection.alias:
+ if isinstance(selection, exp.Subquery):
+ if not selection.output_name:
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
elif not isinstance(selection, exp.Alias):
- alias_ = alias(exp.column(""), f"_col_{i}")
+ alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
alias_.set("this", selection)
selection = alias_
@@ -346,20 +341,30 @@ class _Resolver:
self._unambiguous_columns = None
self._all_columns = None
- def get_table(self, column_name):
+ def get_table(self, column_name: str) -> t.Optional[str]:
"""
Get the table for a column name.
Args:
- column_name (str)
+ column_name: The column name to find the table for.
Returns:
- (str) table name
+ The table name if it can be found/inferred.
"""
if self._unambiguous_columns is None:
self._unambiguous_columns = self._get_unambiguous_columns(
self._get_all_source_columns()
)
- return self._unambiguous_columns.get(column_name)
+
+ table = self._unambiguous_columns.get(column_name)
+
+ if not table:
+ sources_without_schema = tuple(
+ source for source, columns in self._get_all_source_columns().items() if not columns
+ )
+ if len(sources_without_schema) == 1:
+ return sources_without_schema[0]
+
+ return table
@property
def all_columns(self):
@@ -379,10 +384,7 @@ class _Resolver:
# If referencing a table, return the columns from the schema
if isinstance(source, exp.Table):
- try:
- return self.schema.column_names(source, only_visible)
- except Exception as e:
- raise SchemaError(str(e)) from e
+ return self.schema.column_names(source, only_visible)
if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
return source.expression.alias_column_names