summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/qualify_tables.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-01-31 05:44:41 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-01-31 05:44:41 +0000
commit376de8b6892deca7dc5d83035c047f1e13eb67ea (patch)
tree334a1753cd914294aa99128fac3fb59bf14dc10f /sqlglot/optimizer/qualify_tables.py
parentReleasing debian version 20.9.0-1. (diff)
downloadsqlglot-376de8b6892deca7dc5d83035c047f1e13eb67ea.tar.xz
sqlglot-376de8b6892deca7dc5d83035c047f1e13eb67ea.zip
Merging upstream version 20.11.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/qualify_tables.py')
-rw-r--r--sqlglot/optimizer/qualify_tables.py31
1 files changed, 25 insertions, 6 deletions
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index e0fe641..d460e81 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -4,12 +4,14 @@ import itertools
import typing as t
from sqlglot import alias, exp
-from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
+
def qualify_tables(
expression: E,
@@ -46,6 +48,18 @@ def qualify_tables(
db = exp.parse_identifier(db, dialect=dialect) if db else None
catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
+ def _qualify(table: exp.Table) -> None:
+ if isinstance(table.this, exp.Identifier):
+ if not table.args.get("db"):
+ table.set("db", db)
+ if not table.args.get("catalog") and table.args.get("db"):
+ table.set("catalog", catalog)
+
+ if not isinstance(expression, exp.Subqueryable):
+ for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)):
+ if isinstance(node, exp.Table):
+ _qualify(node)
+
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
if isinstance(derived_table, exp.Subquery):
@@ -66,11 +80,7 @@ def qualify_tables(
for name, source in scope.sources.items():
if isinstance(source, exp.Table):
- if isinstance(source.this, exp.Identifier):
- if not source.args.get("db"):
- source.set("db", db)
- if not source.args.get("catalog") and source.args.get("db"):
- source.set("catalog", catalog)
+ _qualify(source)
pivots = pivots = source.args.get("pivots")
if not source.alias:
@@ -107,5 +117,14 @@ def qualify_tables(
if isinstance(udtf, exp.Values) and not table_alias.columns:
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
+ else:
+ for node, parent, _ in scope.walk():
+ if (
+ isinstance(node, exp.Table)
+ and not node.alias
+ and isinstance(parent, (exp.From, exp.Join))
+ ):
+ # Mutates the table by attaching an alias to it
+ alias(node, node.name, copy=False, table=True)
return expression