summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/qualify_tables.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/qualify_tables.py')
-rw-r--r--sqlglot/optimizer/qualify_tables.py54
1 files changed, 54 insertions, 0 deletions
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
new file mode 100644
index 0000000..9f8b9f5
--- /dev/null
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -0,0 +1,54 @@
+import itertools
+
+from sqlglot import alias, exp
+from sqlglot.optimizer.scope import traverse_scope
+
+
+def qualify_tables(expression, db=None, catalog=None):
+ """
+ Rewrite sqlglot AST to have fully qualified tables.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
+ >>> qualify_tables(expression, db="db").sql()
+ 'SELECT 1 FROM db.tbl AS tbl'
+
+ Args:
+ expression (sqlglot.Expression): expression to qualify
+ db (str): Database name
+ catalog (str): Catalog name
+ Returns:
+ sqlglot.Expression: qualified expression
+ """
+ sequence = itertools.count()
+
+ for scope in traverse_scope(expression):
+ for derived_table in scope.ctes + scope.derived_tables:
+ if not derived_table.args.get("alias"):
+ alias_ = f"_q_{next(sequence)}"
+ derived_table.set(
+ "alias", exp.TableAlias(this=exp.to_identifier(alias_))
+ )
+ scope.rename_source(None, alias_)
+
+ for source in scope.sources.values():
+ if isinstance(source, exp.Table):
+ identifier = isinstance(source.this, exp.Identifier)
+
+ if identifier:
+ if not source.args.get("db"):
+ source.set("db", exp.to_identifier(db))
+ if not source.args.get("catalog"):
+ source.set("catalog", exp.to_identifier(catalog))
+
+ if not isinstance(source.parent, exp.Alias):
+ source.replace(
+ alias(
+ source.copy(),
+ source.this if identifier else f"_q_{next(sequence)}",
+ table=True,
+ )
+ )
+
+ return expression