diff options
Diffstat (limited to 'sqlglot/optimizer/qualify_columns.py')
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 26 |
1 files changed, 19 insertions, 7 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 0a31246..6ac39f0 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -3,11 +3,12 @@ import typing as t from sqlglot import alias, exp from sqlglot.errors import OptimizeError +from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import ensure_schema -def qualify_columns(expression, schema): +def qualify_columns(expression, schema, expand_laterals=True): """ Rewrite sqlglot AST to have fully qualified columns. @@ -26,6 +27,9 @@ def qualify_columns(expression, schema): """ schema = ensure_schema(schema) + if not schema.mapping and expand_laterals: + expression = _expand_laterals(expression) + for scope in traverse_scope(expression): resolver = Resolver(scope, schema) _pop_table_column_aliases(scope.ctes) @@ -39,6 +43,9 @@ def qualify_columns(expression, schema): _expand_group_by(scope, resolver) _expand_order_by(scope) + if schema.mapping and expand_laterals: + expression = _expand_laterals(expression) + return expression @@ -124,7 +131,7 @@ def _expand_using(scope, resolver): tables[join_table] = None join.args.pop("using") - join.set("on", exp.and_(*conditions)) + join.set("on", exp.and_(*conditions, copy=False)) if column_tables: for column in scope.columns: @@ -240,7 +247,9 @@ def _qualify_columns(scope, resolver): # column_table can be a '' because bigquery unnest has no table alias if column_table: column.set("table", column_table) - elif column_table not in scope.sources: + elif column_table not in scope.sources and ( + not scope.parent or column_table not in scope.parent.sources + ): # structs are used like tables (e.g. "struct"."field"), so they need to be qualified # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...)) @@ -376,10 +385,13 @@ def _qualify_outputs(scope): if not selection.output_name: selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) elif not isinstance(selection, exp.Alias) and not selection.is_star: - alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}") - alias_.set("this", selection) - selection = alias_ - + selection = alias( + selection, + alias=selection.output_name or f"_col_{i}", + quoted=True + if isinstance(selection, exp.Column) and selection.this.quoted + else None, + ) if aliased_column: selection.set("alias", exp.to_identifier(aliased_column)) |