sqlglot.optimizer.qualify_tables
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_tables(
    expression: E,
    db: t.Optional[str | exp.Identifier] = None,
    catalog: t.Optional[str | exp.Identifier] = None,
    schema: t.Optional[Schema] = None,
    dialect: DialectType = None,
) -> E:
    """
    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.

    Besides attaching `db`/`catalog` to bare table names, this pass gives a generated
    alias to every table, derived table/CTE, pivot and table-valued function source
    that lacks one. It also rewrites columns that reference a table through a dotted
    path (e.g. `db.tbl.col`) so that they reference that table's alias instead. The
    expression is modified in place and also returned.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'
        >>>
        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
        >>> qualify_tables(expression).sql()
        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'

    Args:
        expression: Expression to qualify
        db: Database name
        catalog: Catalog name. Only applied to tables that have (or are given) a database.
        schema: A schema to populate with the columns of any `READ_CSV` table sources.
        dialect: The dialect used to parse the `db` and `catalog` names.

    Returns:
        The qualified expression.
    """
    next_alias_name = name_sequence("_q_")
    db = exp.parse_identifier(db, dialect=dialect) if db else None
    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None

    def _qualify(table: exp.Table) -> None:
        """Fill in the missing db (and, once a db is present, catalog) of `table`."""
        # Only tables named by a plain identifier are qualified; other table
        # expressions (e.g. a READ_CSV call) are left untouched.
        if isinstance(table.this, exp.Identifier):
            if not table.args.get("db"):
                table.set("db", db)
            # A catalog is only attached when the table ends up with a db part.
            if not table.args.get("catalog") and table.args.get("db"):
                table.set("catalog", catalog)

    # Non-query statements (presumably DDL/DML such as CREATE or INSERT) can reference
    # tables outside any query; qualify those directly. Nested queries are pruned here,
    # presumably because the scope traversal below qualifies their table sources.
    if (db or catalog) and not isinstance(expression, exp.Query):
        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
            if isinstance(node, exp.Table):
                _qualify(node)

    for scope in traverse_scope(expression):
        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
            if isinstance(derived_table, exp.Subquery):
                unnested = derived_table.unnest()
                if isinstance(unnested, exp.Table):
                    # A parenthesized join such as (t1 JOIN t2) is a table carrying joins;
                    # turn it into SELECT * FROM t1, ... so it becomes a real subquery.
                    joins = unnested.args.pop("joins", None)
                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
                    derived_table.this.set("joins", joins)

            if not derived_table.args.get("alias"):
                alias_ = next_alias_name()
                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
                # Rename the scope's unnamed (None-keyed) source to the generated alias.
                scope.rename_source(None, alias_)

            pivots = derived_table.args.get("pivots")
            if pivots and not pivots[0].alias:
                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))

        # Dotted table path as written in the query (e.g. "db.tbl") -> that table's alias.
        table_aliases: t.Dict[str, exp.Identifier] = {}

        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                pivots = source.args.get("pivots")
                if not source.alias:
                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
                    if pivots and pivots[0].alias == name:
                        name = source.name

                    # Mutates the source by attaching an alias to it
                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)

                # Recorded before `_qualify` runs, so the key reflects only the parts written
                # in the query, not the db/catalog injected just below.
                table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
                    source.alias
                )

                _qualify(source)

                if pivots and not pivots[0].alias:
                    pivots[0].set(
                        "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
                    )

                if schema and isinstance(source.this, exp.ReadCSV):
                    # Column names come from the first (header) row and types from the second row.
                    # NOTE(review): presumably csv_reader wraps the stdlib csv.reader, which yields
                    # strings, so every column would be typed "str" — confirm.
                    # NOTE(review): next() raises StopIteration if the file has fewer than two rows.
                    with csv_reader(source.this) as reader:
                        header = next(reader)
                        columns = next(reader)
                        schema.add_table(
                            source,
                            {k: type(v).__name__ for k, v in zip(header, columns)},
                            match_depth=False,
                        )
            elif isinstance(source, Scope) and source.is_udtf:
                udtf = source.expression
                table_alias = udtf.args.get("alias") or exp.TableAlias(
                    this=exp.to_identifier(next_alias_name())
                )
                udtf.set("alias", table_alias)

                # The alias may be present but unnamed; give it a generated name.
                if not table_alias.name:
                    table_alias.set("this", exp.to_identifier(next_alias_name()))
                # Name VALUES columns positionally (_col_0, _col_1, ...) from the first row's width.
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    for i, _ in enumerate(udtf.expressions[0].expressions):
                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
            else:
                # For any other source (a nested scope, e.g. a CTE or derived table), alias the
                # bare tables in this scope's FROM/JOIN clauses with their own names. This runs
                # once per such source, but is idempotent because aliased tables are skipped.
                for node in scope.walk():
                    if (
                        isinstance(node, exp.Table)
                        and not node.alias
                        and isinstance(node.parent, (exp.From, exp.Join))
                    ):
                        # Mutates the table by attaching an alias to it
                        alias(node, node.name, copy=False, table=True)

        # Columns qualified down to the db level (e.g. db.tbl.col) whose table path matches
        # a known source are rewritten to reference that source's alias.
        for column in scope.columns:
            if column.db:
                qualified_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))

                if qualified_alias:
                    # Clear every qualifier part after the column name itself, then point
                    # the column at the alias.
                    for p in exp.COLUMN_PARTS[1:]:
                        column.set(p, None)
                    column.set("table", qualified_alias)

    return expression
def
qualify_tables( expression: ~E, db: Union[sqlglot.expressions.Identifier, str, NoneType] = None, catalog: Union[sqlglot.expressions.Identifier, str, NoneType] = None, schema: Optional[sqlglot.schema.Schema] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> ~E:
def qualify_tables(
    expression: E,
    db: t.Optional[str | exp.Identifier] = None,
    catalog: t.Optional[str | exp.Identifier] = None,
    schema: t.Optional[Schema] = None,
    dialect: DialectType = None,
) -> E:
    """
    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.

    Besides attaching `db`/`catalog` to bare table names, this pass gives a generated
    alias to every table, derived table/CTE, pivot and table-valued function source
    that lacks one. It also rewrites columns that reference a table through a dotted
    path (e.g. `db.tbl.col`) so that they reference that table's alias instead. The
    expression is modified in place and also returned.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'
        >>>
        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
        >>> qualify_tables(expression).sql()
        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'

    Args:
        expression: Expression to qualify
        db: Database name
        catalog: Catalog name. Only applied to tables that have (or are given) a database.
        schema: A schema to populate with the columns of any `READ_CSV` table sources.
        dialect: The dialect used to parse the `db` and `catalog` names.

    Returns:
        The qualified expression.
    """
    next_alias_name = name_sequence("_q_")
    db = exp.parse_identifier(db, dialect=dialect) if db else None
    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None

    def _qualify(table: exp.Table) -> None:
        """Fill in the missing db (and, once a db is present, catalog) of `table`."""
        # Only tables named by a plain identifier are qualified; other table
        # expressions (e.g. a READ_CSV call) are left untouched.
        if isinstance(table.this, exp.Identifier):
            if not table.args.get("db"):
                table.set("db", db)
            # A catalog is only attached when the table ends up with a db part.
            if not table.args.get("catalog") and table.args.get("db"):
                table.set("catalog", catalog)

    # Non-query statements (presumably DDL/DML such as CREATE or INSERT) can reference
    # tables outside any query; qualify those directly. Nested queries are pruned here,
    # presumably because the scope traversal below qualifies their table sources.
    if (db or catalog) and not isinstance(expression, exp.Query):
        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
            if isinstance(node, exp.Table):
                _qualify(node)

    for scope in traverse_scope(expression):
        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
            if isinstance(derived_table, exp.Subquery):
                unnested = derived_table.unnest()
                if isinstance(unnested, exp.Table):
                    # A parenthesized join such as (t1 JOIN t2) is a table carrying joins;
                    # turn it into SELECT * FROM t1, ... so it becomes a real subquery.
                    joins = unnested.args.pop("joins", None)
                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
                    derived_table.this.set("joins", joins)

            if not derived_table.args.get("alias"):
                alias_ = next_alias_name()
                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
                # Rename the scope's unnamed (None-keyed) source to the generated alias.
                scope.rename_source(None, alias_)

            pivots = derived_table.args.get("pivots")
            if pivots and not pivots[0].alias:
                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))

        # Dotted table path as written in the query (e.g. "db.tbl") -> that table's alias.
        table_aliases: t.Dict[str, exp.Identifier] = {}

        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                pivots = source.args.get("pivots")
                if not source.alias:
                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
                    if pivots and pivots[0].alias == name:
                        name = source.name

                    # Mutates the source by attaching an alias to it
                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)

                # Recorded before `_qualify` runs, so the key reflects only the parts written
                # in the query, not the db/catalog injected just below.
                table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
                    source.alias
                )

                _qualify(source)

                if pivots and not pivots[0].alias:
                    pivots[0].set(
                        "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
                    )

                if schema and isinstance(source.this, exp.ReadCSV):
                    # Column names come from the first (header) row and types from the second row.
                    # NOTE(review): presumably csv_reader wraps the stdlib csv.reader, which yields
                    # strings, so every column would be typed "str" — confirm.
                    # NOTE(review): next() raises StopIteration if the file has fewer than two rows.
                    with csv_reader(source.this) as reader:
                        header = next(reader)
                        columns = next(reader)
                        schema.add_table(
                            source,
                            {k: type(v).__name__ for k, v in zip(header, columns)},
                            match_depth=False,
                        )
            elif isinstance(source, Scope) and source.is_udtf:
                udtf = source.expression
                table_alias = udtf.args.get("alias") or exp.TableAlias(
                    this=exp.to_identifier(next_alias_name())
                )
                udtf.set("alias", table_alias)

                # The alias may be present but unnamed; give it a generated name.
                if not table_alias.name:
                    table_alias.set("this", exp.to_identifier(next_alias_name()))
                # Name VALUES columns positionally (_col_0, _col_1, ...) from the first row's width.
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    for i, _ in enumerate(udtf.expressions[0].expressions):
                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
            else:
                # For any other source (a nested scope, e.g. a CTE or derived table), alias the
                # bare tables in this scope's FROM/JOIN clauses with their own names. This runs
                # once per such source, but is idempotent because aliased tables are skipped.
                for node in scope.walk():
                    if (
                        isinstance(node, exp.Table)
                        and not node.alias
                        and isinstance(node.parent, (exp.From, exp.Join))
                    ):
                        # Mutates the table by attaching an alias to it
                        alias(node, node.name, copy=False, table=True)

        # Columns qualified down to the db level (e.g. db.tbl.col) whose table path matches
        # a known source are rewritten to reference that source's alias.
        for column in scope.columns:
            if column.db:
                qualified_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))

                if qualified_alias:
                    # Clear every qualifier part after the column name itself, then point
                    # the column at the alias.
                    for p in exp.COLUMN_PARTS[1:]:
                        column.set(p, None)
                    column.set("table", qualified_alias)

    return expression
Rewrite sqlglot AST to have fully qualified tables. Join constructs such as (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
>>> qualify_tables(expression).sql()
'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
Arguments:
- expression: Expression to qualify
- db: Database name
- catalog: Catalog name. Only applied to tables that have (or are given) a database.
- schema: A schema to populate with the columns of any `READ_CSV` table sources.
- dialect: The dialect used to parse the `db` and `catalog` names.
Returns:
The qualified expression.