summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/qualify_tables.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/qualify_tables.py')
-rw-r--r--sqlglot/optimizer/qualify_tables.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 3a43e8f..57ecabe 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -1,8 +1,11 @@
+from __future__ import annotations
+
import itertools
import typing as t
from sqlglot import alias, exp
from sqlglot._typing import E
+from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema
@@ -10,9 +13,10 @@ from sqlglot.schema import Schema
def qualify_tables(
expression: E,
- db: t.Optional[str] = None,
- catalog: t.Optional[str] = None,
+ db: t.Optional[str | exp.Identifier] = None,
+ catalog: t.Optional[str | exp.Identifier] = None,
schema: t.Optional[Schema] = None,
+ dialect: DialectType = None,
) -> E:
"""
Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
@@ -33,11 +37,14 @@ def qualify_tables(
db: Database name
catalog: Catalog name
schema: A schema to populate
+ dialect: The dialect to parse catalog and schema into.
Returns:
The qualified expression.
"""
next_alias_name = name_sequence("_q_")
+ db = exp.parse_identifier(db, dialect=dialect) if db else None
+ catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
@@ -61,9 +68,9 @@ def qualify_tables(
if isinstance(source, exp.Table):
if isinstance(source.this, exp.Identifier):
if not source.args.get("db"):
- source.set("db", exp.to_identifier(db))
+ source.set("db", db)
if not source.args.get("catalog") and source.args.get("db"):
- source.set("catalog", exp.to_identifier(catalog))
+ source.set("catalog", catalog)
if not source.alias:
# Mutates the source by attaching an alias to it