summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/qualify_tables.py
diff options
context:
space:
mode:
authorDaniel Baumann <mail@daniel-baumann.ch>2023-12-10 10:46:01 +0000
committerDaniel Baumann <mail@daniel-baumann.ch>2023-12-10 10:46:01 +0000
commit8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch)
tree6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/optimizer/qualify_tables.py
parentReleasing debian version 19.0.1-1. (diff)
downloadsqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz
sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/optimizer/qualify_tables.py')
-rw-r--r--sqlglot/optimizer/qualify_tables.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 3a43e8f..57ecabe 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -1,8 +1,11 @@
+from __future__ import annotations
+
import itertools
import typing as t
from sqlglot import alias, exp
from sqlglot._typing import E
+from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema
@@ -10,9 +13,10 @@ from sqlglot.schema import Schema
def qualify_tables(
expression: E,
- db: t.Optional[str] = None,
- catalog: t.Optional[str] = None,
+ db: t.Optional[str | exp.Identifier] = None,
+ catalog: t.Optional[str | exp.Identifier] = None,
schema: t.Optional[Schema] = None,
+ dialect: DialectType = None,
) -> E:
"""
Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
@@ -33,11 +37,14 @@ def qualify_tables(
db: Database name
catalog: Catalog name
schema: A schema to populate
+ dialect: The dialect to parse catalog and schema into.
Returns:
The qualified expression.
"""
next_alias_name = name_sequence("_q_")
+ db = exp.parse_identifier(db, dialect=dialect) if db else None
+ catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
@@ -61,9 +68,9 @@ def qualify_tables(
if isinstance(source, exp.Table):
if isinstance(source.this, exp.Identifier):
if not source.args.get("db"):
- source.set("db", exp.to_identifier(db))
+ source.set("db", db)
if not source.args.get("catalog") and source.args.get("db"):
- source.set("catalog", exp.to_identifier(catalog))
+ source.set("catalog", catalog)
if not source.alias:
# Mutates the source by attaching an alias to it