diff options
Diffstat (limited to 'sqlglot/optimizer/qualify_tables.py')
-rw-r--r-- | sqlglot/optimizer/qualify_tables.py | 45 |
1 files changed, 31 insertions, 14 deletions
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 1b451a6..fcc5f26 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -1,11 +1,19 @@ import itertools +import typing as t from sqlglot import alias, exp -from sqlglot.helper import csv_reader +from sqlglot._typing import E +from sqlglot.helper import csv_reader, name_sequence from sqlglot.optimizer.scope import Scope, traverse_scope +from sqlglot.schema import Schema -def qualify_tables(expression, db=None, catalog=None, schema=None): +def qualify_tables( + expression: E, + db: t.Optional[str] = None, + catalog: t.Optional[str] = None, + schema: t.Optional[Schema] = None, +) -> E: """ Rewrite sqlglot AST to have fully qualified tables. Additionally, this replaces "join constructs" (*) by equivalent SELECT * subqueries. @@ -21,19 +29,17 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): 'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0' Args: - expression (sqlglot.Expression): expression to qualify - db (str): Database name - catalog (str): Catalog name + expression: Expression to qualify + db: Database name + catalog: Catalog name schema: A schema to populate Returns: - sqlglot.Expression: qualified expression + The qualified expression. (*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html """ - sequence = itertools.count() - - next_name = lambda: f"_q_{next(sequence)}" + next_alias_name = name_sequence("_q_") for scope in traverse_scope(expression): for derived_table in itertools.chain(scope.ctes, scope.derived_tables): @@ -44,10 +50,14 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False)) if not derived_table.args.get("alias"): - alias_ = f"_q_{next(sequence)}" + alias_ = next_alias_name() derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_))) scope.rename_source(None, alias_) + pivots = derived_table.args.get("pivots") + if pivots and not pivots[0].alias: + pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))) + for name, source in scope.sources.items(): if isinstance(source, exp.Table): if isinstance(source.this, exp.Identifier): @@ -59,12 +69,19 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): if not source.alias: source = source.replace( alias( - source.copy(), - name if name else next_name(), + source, + name or source.name or next_alias_name(), + copy=True, table=True, ) ) + pivots = source.args.get("pivots") + if pivots and not pivots[0].alias: + pivots[0].set( + "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())) + ) + if schema and isinstance(source.this, exp.ReadCSV): with csv_reader(source.this) as reader: header = next(reader) @@ -74,11 +91,11 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): ) elif isinstance(source, Scope) and source.is_udtf: udtf = source.expression - table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name()) + table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_alias_name()) udtf.set("alias", table_alias) if not table_alias.name: - table_alias.set("this", next_name()) + table_alias.set("this", next_alias_name()) if isinstance(udtf, exp.Values) and not table_alias.columns: for i, e in enumerate(udtf.expressions[0].expressions): table_alias.append("columns", exp.to_identifier(f"_col_{i}")) |