sqlglot.optimizer.qualify_tables
import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema


def qualify_tables(
    expression: E,
    db: t.Optional[str] = None,
    catalog: t.Optional[str] = None,
    schema: t.Optional[Schema] = None,
) -> E:
    """
    Rewrite sqlglot AST to have fully qualified tables. Additionally, this
    replaces "join constructs" (*) by equivalent SELECT * subqueries.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'
        >>>
        >>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
        >>> qualify_tables(expression).sql()
        'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'

    Args:
        expression: Expression to qualify. Its nodes are modified in place.
        db: Database name, applied to tables that do not already have one.
        catalog: Catalog name, applied to tables that do not already have one.
            It is only applied to tables that have a database (either their own
            or `db`): a catalog without a database would be generated as
            `catalog.tbl`, which reads back as `db.tbl`.
        schema: A schema to populate with the columns of `READ_CSV` sources.

    Returns:
        The qualified expression.

    (*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
    """
    next_alias_name = name_sequence("_q_")

    for scope in traverse_scope(expression):
        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
            # Expand join construct
            if isinstance(derived_table, exp.Subquery):
                unnested = derived_table.unnest()
                if isinstance(unnested, exp.Table):
                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))

            if not derived_table.args.get("alias"):
                alias_ = next_alias_name()
                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
                # Re-key the source that the scope tracked under `None` to the new alias.
                scope.rename_source(None, alias_)

            pivots = derived_table.args.get("pivots")
            if pivots and not pivots[0].alias:
                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))

        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                if isinstance(source.this, exp.Identifier):
                    if not source.args.get("db"):
                        source.set("db", exp.to_identifier(db))
                    # Only attach a catalog when there is a db to go with it; otherwise
                    # `catalog.tbl` would be generated and re-parsed as `db.tbl`.
                    if not source.args.get("catalog") and source.args.get("db"):
                        source.set("catalog", exp.to_identifier(catalog))

                if not source.alias:
                    source = source.replace(
                        alias(
                            source,
                            name or source.name or next_alias_name(),
                            copy=True,
                            table=True,
                        )
                    )

                pivots = source.args.get("pivots")
                if pivots and not pivots[0].alias:
                    pivots[0].set(
                        "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
                    )

                if schema and isinstance(source.this, exp.ReadCSV):
                    # Column names come from the CSV header row; column types are the
                    # Python type names of the values in the first data row.
                    with csv_reader(source.this) as reader:
                        header = next(reader)
                        columns = next(reader)
                        schema.add_table(
                            source, {k: type(v).__name__ for k, v in zip(header, columns)}
                        )
            elif isinstance(source, Scope) and source.is_udtf:
                udtf = source.expression
                table_alias = udtf.args.get("alias") or exp.TableAlias(
                    this=exp.to_identifier(next_alias_name())
                )
                udtf.set("alias", table_alias)

                if not table_alias.name:
                    table_alias.set("this", exp.to_identifier(next_alias_name()))
                # When a VALUES alias lists no columns, name them positionally
                # (_col_0, _col_1, ...) based on the width of the first row.
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    for i, _ in enumerate(udtf.expressions[0].expressions):
                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))

    return expression
def
qualify_tables( expression: ~E, db: Optional[str] = None, catalog: Optional[str] = None, schema: Optional[sqlglot.schema.Schema] = None) -> ~E:
def qualify_tables(
    expression: E,
    db: t.Optional[str] = None,
    catalog: t.Optional[str] = None,
    schema: t.Optional[Schema] = None,
) -> E:
    """
    Rewrite sqlglot AST to have fully qualified tables. Additionally, this
    replaces "join constructs" (*) by equivalent SELECT * subqueries.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'
        >>>
        >>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
        >>> qualify_tables(expression).sql()
        'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'

    Args:
        expression: Expression to qualify. Its nodes are modified in place.
        db: Database name, applied to tables that do not already have one.
        catalog: Catalog name, applied to tables that do not already have one.
            It is only applied to tables that have a database (either their own
            or `db`): a catalog without a database would be generated as
            `catalog.tbl`, which reads back as `db.tbl`.
        schema: A schema to populate with the columns of `READ_CSV` sources.

    Returns:
        The qualified expression.

    (*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
    """
    next_alias_name = name_sequence("_q_")

    for scope in traverse_scope(expression):
        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
            # Expand join construct
            if isinstance(derived_table, exp.Subquery):
                unnested = derived_table.unnest()
                if isinstance(unnested, exp.Table):
                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))

            if not derived_table.args.get("alias"):
                alias_ = next_alias_name()
                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
                # Re-key the source that the scope tracked under `None` to the new alias.
                scope.rename_source(None, alias_)

            pivots = derived_table.args.get("pivots")
            if pivots and not pivots[0].alias:
                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))

        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                if isinstance(source.this, exp.Identifier):
                    if not source.args.get("db"):
                        source.set("db", exp.to_identifier(db))
                    # Only attach a catalog when there is a db to go with it; otherwise
                    # `catalog.tbl` would be generated and re-parsed as `db.tbl`.
                    if not source.args.get("catalog") and source.args.get("db"):
                        source.set("catalog", exp.to_identifier(catalog))

                if not source.alias:
                    source = source.replace(
                        alias(
                            source,
                            name or source.name or next_alias_name(),
                            copy=True,
                            table=True,
                        )
                    )

                pivots = source.args.get("pivots")
                if pivots and not pivots[0].alias:
                    pivots[0].set(
                        "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
                    )

                if schema and isinstance(source.this, exp.ReadCSV):
                    # Column names come from the CSV header row; column types are the
                    # Python type names of the values in the first data row.
                    with csv_reader(source.this) as reader:
                        header = next(reader)
                        columns = next(reader)
                        schema.add_table(
                            source, {k: type(v).__name__ for k, v in zip(header, columns)}
                        )
            elif isinstance(source, Scope) and source.is_udtf:
                udtf = source.expression
                table_alias = udtf.args.get("alias") or exp.TableAlias(
                    this=exp.to_identifier(next_alias_name())
                )
                udtf.set("alias", table_alias)

                if not table_alias.name:
                    table_alias.set("this", exp.to_identifier(next_alias_name()))
                # When a VALUES alias lists no columns, name them positionally
                # (_col_0, _col_1, ...) based on the width of the first row.
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    for i, _ in enumerate(udtf.expressions[0].expressions):
                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))

    return expression
Rewrite a sqlglot AST so that all tables are fully qualified. Additionally, this replaces "join constructs" (*) with equivalent SELECT * subqueries.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
>>> qualify_tables(expression).sql()
'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'
Arguments:
- expression: Expression to qualify
- db: Database name
- catalog: Catalog name
- schema: A schema to populate
Returns:
The qualified expression.
(*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html