diff options
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/__init__.py | 1 | ||||
-rw-r--r-- | sqlglot/optimizer/isolate_table_selects.py | 7 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 54 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 4 |
4 files changed, 36 insertions, 30 deletions
diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py index bba0878..719a77e 100644 --- a/sqlglot/optimizer/__init__.py +++ b/sqlglot/optimizer/__init__.py @@ -1 +1,2 @@ from sqlglot.optimizer.optimizer import RULES, optimize +from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py index 652cdef..5bd7b30 100644 --- a/sqlglot/optimizer/isolate_table_selects.py +++ b/sqlglot/optimizer/isolate_table_selects.py @@ -1,15 +1,18 @@ from sqlglot import alias, exp from sqlglot.errors import OptimizeError from sqlglot.optimizer.scope import traverse_scope +from sqlglot.schema import ensure_schema -def isolate_table_selects(expression): +def isolate_table_selects(expression, schema=None): + schema = ensure_schema(schema) + for scope in traverse_scope(expression): if len(scope.selected_sources) == 1: continue for (_, source) in scope.selected_sources.values(): - if not isinstance(source, exp.Table): + if not isinstance(source, exp.Table) or not schema.column_names(source): continue if not source.alias: diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 8da4e43..54425a8 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -1,7 +1,8 @@ import itertools +import typing as t from sqlglot import alias, exp -from sqlglot.errors import OptimizeError, SchemaError +from sqlglot.errors import OptimizeError from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import ensure_schema @@ -190,20 +191,15 @@ def _qualify_columns(scope, resolver): column_table = column.table column_name = column.name - if ( - column_table - and column_table in scope.sources - and column_name not in resolver.get_source_columns(column_table) - ): - raise OptimizeError(f"Unknown column: {column_name}") + if column_table and column_table in scope.sources: + source_columns = resolver.get_source_columns(column_table) + if source_columns and column_name not in source_columns: + raise OptimizeError(f"Unknown column: {column_name}") if not column_table: column_table = resolver.get_table(column_name) if not scope.is_subquery and not scope.is_udtf: - if column_name not in resolver.all_columns: - raise OptimizeError(f"Unknown column: {column_name}") - if column_table is None: raise OptimizeError(f"Ambiguous column: {column_name}") @@ -265,6 +261,10 @@ def _expand_stars(scope, resolver): if table not in scope.sources: raise OptimizeError(f"Unknown table: {table}") columns = resolver.get_source_columns(table, only_visible=True) + if not columns: + raise OptimizeError( + f"Table has no schema/columns. Cannot expand star for table: {table}." + ) table_id = id(table) for name in columns: if name not in except_columns.get(table_id, set()): @@ -306,16 +306,11 @@ def _qualify_outputs(scope): for i, (selection, aliased_column) in enumerate( itertools.zip_longest(scope.selects, scope.outer_column_list) ): - if isinstance(selection, exp.Column): - # convoluted setter because a simple selection.replace(alias) would require a copy - alias_ = alias(exp.column(""), alias=selection.name) - alias_.set("this", selection) - selection = alias_ - elif isinstance(selection, exp.Subquery): - if not selection.alias: + if isinstance(selection, exp.Subquery): + if not selection.output_name: selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) elif not isinstance(selection, exp.Alias): - alias_ = alias(exp.column(""), f"_col_{i}") + alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}") alias_.set("this", selection) selection = alias_ @@ -346,20 +341,30 @@ class _Resolver: self._unambiguous_columns = None self._all_columns = None - def get_table(self, column_name): + def get_table(self, column_name: str) -> t.Optional[str]: """ Get the table for a column name. Args: - column_name (str) + column_name: The column name to find the table for. Returns: - (str) table name + The table name if it can be found/inferred. """ if self._unambiguous_columns is None: self._unambiguous_columns = self._get_unambiguous_columns( self._get_all_source_columns() ) - return self._unambiguous_columns.get(column_name) + + table = self._unambiguous_columns.get(column_name) + + if not table: + sources_without_schema = tuple( + source for source, columns in self._get_all_source_columns().items() if not columns + ) + if len(sources_without_schema) == 1: + return sources_without_schema[0] + + return table @property def all_columns(self): @@ -379,10 +384,7 @@ class _Resolver: # If referencing a table, return the columns from the schema if isinstance(source, exp.Table): - try: - return self.schema.column_names(source, only_visible) - except Exception as e: - raise SchemaError(str(e)) from e + return self.schema.column_names(source, only_visible) if isinstance(source, Scope) and isinstance(source.expression, exp.Values): return source.expression.alias_column_names diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 6125e4e..5a3ed5a 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -230,7 +230,7 @@ class Scope: column for scope in self.subquery_scopes for column in scope.external_columns ] - named_outputs = {e.alias_or_name for e in self.expression.expressions} + named_selects = set(self.expression.named_selects) self._columns = [] for column in columns + external_columns: @@ -238,7 +238,7 @@ class Scope: if ( not ancestor or column.table - or (column.name not in named_outputs and not isinstance(ancestor, exp.Hint)) + or (column.name not in named_selects and not isinstance(ancestor, exp.Hint)) ): self._columns.append(column) |