diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-01-31 05:44:41 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-01-31 05:44:41 +0000 |
commit | 376de8b6892deca7dc5d83035c047f1e13eb67ea (patch) | |
tree | 334a1753cd914294aa99128fac3fb59bf14dc10f /sqlglot/optimizer/qualify_tables.py | |
parent | Releasing debian version 20.9.0-1. (diff) | |
download | sqlglot-376de8b6892deca7dc5d83035c047f1e13eb67ea.tar.xz sqlglot-376de8b6892deca7dc5d83035c047f1e13eb67ea.zip |
Merging upstream version 20.11.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/qualify_tables.py')
-rw-r--r-- | sqlglot/optimizer/qualify_tables.py | 31 |
1 files changed, 25 insertions, 6 deletions
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index e0fe641..d460e81 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -4,12 +4,14 @@ import itertools import typing as t from sqlglot import alias, exp -from sqlglot._typing import E from sqlglot.dialects.dialect import DialectType from sqlglot.helper import csv_reader, name_sequence from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import Schema +if t.TYPE_CHECKING: + from sqlglot._typing import E + def qualify_tables( expression: E, @@ -46,6 +48,18 @@ def qualify_tables( db = exp.parse_identifier(db, dialect=dialect) if db else None catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None + def _qualify(table: exp.Table) -> None: + if isinstance(table.this, exp.Identifier): + if not table.args.get("db"): + table.set("db", db) + if not table.args.get("catalog") and table.args.get("db"): + table.set("catalog", catalog) + + if not isinstance(expression, exp.Subqueryable): + for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)): + if isinstance(node, exp.Table): + _qualify(node) + for scope in traverse_scope(expression): for derived_table in itertools.chain(scope.ctes, scope.derived_tables): if isinstance(derived_table, exp.Subquery): @@ -66,11 +80,7 @@ def qualify_tables( for name, source in scope.sources.items(): if isinstance(source, exp.Table): - if isinstance(source.this, exp.Identifier): - if not source.args.get("db"): - source.set("db", db) - if not source.args.get("catalog") and source.args.get("db"): - source.set("catalog", catalog) + _qualify(source) pivots = pivots = source.args.get("pivots") if not source.alias: @@ -107,5 +117,14 @@ def qualify_tables( if isinstance(udtf, exp.Values) and not table_alias.columns: for i, e in enumerate(udtf.expressions[0].expressions): table_alias.append("columns", exp.to_identifier(f"_col_{i}")) + else: + for node, parent, _ in scope.walk(): + if ( + isinstance(node, exp.Table) + and not node.alias + and isinstance(parent, (exp.From, exp.Join)) + ): + # Mutates the table by attaching an alias to it + alias(node, node.name, copy=False, table=True) return expression |