diff options
Diffstat (limited to 'sqlglot/optimizer/qualify_tables.py')
-rw-r--r-- | sqlglot/optimizer/qualify_tables.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 5d8e0d9..65593bd 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -2,7 +2,7 @@ import itertools from sqlglot import alias, exp from sqlglot.helper import csv_reader -from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.scope import Scope, traverse_scope def qualify_tables(expression, db=None, catalog=None, schema=None): @@ -25,6 +25,8 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): """ sequence = itertools.count() + next_name = lambda: f"_q_{next(sequence)}" + for scope in traverse_scope(expression): for derived_table in scope.ctes + scope.derived_tables: if not derived_table.args.get("alias"): @@ -46,7 +48,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): source = source.replace( alias( source.copy(), - source.this if identifier else f"_q_{next(sequence)}", + source.this if identifier else next_name(), table=True, ) ) @@ -58,5 +60,12 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): schema.add_table( source, {k: type(v).__name__ for k, v in zip(header, columns)} ) + elif isinstance(source, Scope) and source.is_udtf: + udtf = source.expression + table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name()) + udtf.set("alias", table_alias) + + if not table_alias.name: + table_alias.set("this", next_name()) return expression |