Edit on GitHub

sqlglot.optimizer.qualify_tables

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import DialectType
  8from sqlglot.helper import csv_reader, name_sequence
  9from sqlglot.optimizer.scope import Scope, traverse_scope
 10from sqlglot.schema import Schema
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot._typing import E
 14
 15
 16def qualify_tables(
 17    expression: E,
 18    db: t.Optional[str | exp.Identifier] = None,
 19    catalog: t.Optional[str | exp.Identifier] = None,
 20    schema: t.Optional[Schema] = None,
 21    dialect: DialectType = None,
 22) -> E:
 23    """
 24    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
 25    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
 26
 27    Examples:
 28        >>> import sqlglot
 29        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
 30        >>> qualify_tables(expression, db="db").sql()
 31        'SELECT 1 FROM db.tbl AS tbl'
 32        >>>
 33        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
 34        >>> qualify_tables(expression).sql()
 35        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
 36
 37    Args:
 38        expression: Expression to qualify
 39        db: Database name
 40        catalog: Catalog name
 41        schema: A schema to populate
 42        dialect: The dialect to parse catalog and schema into.
 43
 44    Returns:
 45        The qualified expression.
 46    """
 47    next_alias_name = name_sequence("_q_")
 48    db = exp.parse_identifier(db, dialect=dialect) if db else None
 49    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
 50
 51    def _qualify(table: exp.Table) -> None:
 52        if isinstance(table.this, exp.Identifier):
 53            if not table.args.get("db"):
 54                table.set("db", db)
 55            if not table.args.get("catalog") and table.args.get("db"):
 56                table.set("catalog", catalog)
 57
 58    if (db or catalog) and not isinstance(expression, exp.Query):
 59        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
 60            if isinstance(node, exp.Table):
 61                _qualify(node)
 62
 63    for scope in traverse_scope(expression):
 64        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
 65            if isinstance(derived_table, exp.Subquery):
 66                unnested = derived_table.unnest()
 67                if isinstance(unnested, exp.Table):
 68                    joins = unnested.args.pop("joins", None)
 69                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
 70                    derived_table.this.set("joins", joins)
 71
 72            if not derived_table.args.get("alias"):
 73                alias_ = next_alias_name()
 74                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
 75                scope.rename_source(None, alias_)
 76
 77            pivots = derived_table.args.get("pivots")
 78            if pivots and not pivots[0].alias:
 79                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))
 80
 81        table_aliases = {}
 82
 83        for name, source in scope.sources.items():
 84            if isinstance(source, exp.Table):
 85                pivots = pivots = source.args.get("pivots")
 86                if not source.alias:
 87                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
 88                    if pivots and pivots[0].alias == name:
 89                        name = source.name
 90
 91                    # Mutates the source by attaching an alias to it
 92                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)
 93
 94                table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
 95                    source.alias
 96                )
 97
 98                _qualify(source)
 99
100                if pivots and not pivots[0].alias:
101                    pivots[0].set(
102                        "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
103                    )
104
105                if schema and isinstance(source.this, exp.ReadCSV):
106                    with csv_reader(source.this) as reader:
107                        header = next(reader)
108                        columns = next(reader)
109                        schema.add_table(
110                            source,
111                            {k: type(v).__name__ for k, v in zip(header, columns)},
112                            match_depth=False,
113                        )
114            elif isinstance(source, Scope) and source.is_udtf:
115                udtf = source.expression
116                table_alias = udtf.args.get("alias") or exp.TableAlias(
117                    this=exp.to_identifier(next_alias_name())
118                )
119                udtf.set("alias", table_alias)
120
121                if not table_alias.name:
122                    table_alias.set("this", exp.to_identifier(next_alias_name()))
123                if isinstance(udtf, exp.Values) and not table_alias.columns:
124                    for i, e in enumerate(udtf.expressions[0].expressions):
125                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
126            else:
127                for node in scope.walk():
128                    if (
129                        isinstance(node, exp.Table)
130                        and not node.alias
131                        and isinstance(node.parent, (exp.From, exp.Join))
132                    ):
133                        # Mutates the table by attaching an alias to it
134                        alias(node, node.name, copy=False, table=True)
135
136        for column in scope.columns:
137            if column.db:
138                table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))
139
140                if table_alias:
141                    for p in exp.COLUMN_PARTS[1:]:
142                        column.set(p, None)
143                    column.set("table", table_alias)
144
145    return expression
def qualify_tables( expression: ~E, db: Union[sqlglot.expressions.Identifier, str, NoneType] = None, catalog: Union[sqlglot.expressions.Identifier, str, NoneType] = None, schema: Optional[sqlglot.schema.Schema] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> ~E:
 17def qualify_tables(
 18    expression: E,
 19    db: t.Optional[str | exp.Identifier] = None,
 20    catalog: t.Optional[str | exp.Identifier] = None,
 21    schema: t.Optional[Schema] = None,
 22    dialect: DialectType = None,
 23) -> E:
 24    """
 25    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
 26    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
 27
 28    Examples:
 29        >>> import sqlglot
 30        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
 31        >>> qualify_tables(expression, db="db").sql()
 32        'SELECT 1 FROM db.tbl AS tbl'
 33        >>>
 34        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
 35        >>> qualify_tables(expression).sql()
 36        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
 37
 38    Args:
 39        expression: Expression to qualify
 40        db: Database name
 41        catalog: Catalog name
 42        schema: A schema to populate
 43        dialect: The dialect to parse catalog and schema into.
 44
 45    Returns:
 46        The qualified expression.
 47    """
 48    next_alias_name = name_sequence("_q_")
 49    db = exp.parse_identifier(db, dialect=dialect) if db else None
 50    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
 51
 52    def _qualify(table: exp.Table) -> None:
 53        if isinstance(table.this, exp.Identifier):
 54            if not table.args.get("db"):
 55                table.set("db", db)
 56            if not table.args.get("catalog") and table.args.get("db"):
 57                table.set("catalog", catalog)
 58
 59    if (db or catalog) and not isinstance(expression, exp.Query):
 60        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
 61            if isinstance(node, exp.Table):
 62                _qualify(node)
 63
 64    for scope in traverse_scope(expression):
 65        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
 66            if isinstance(derived_table, exp.Subquery):
 67                unnested = derived_table.unnest()
 68                if isinstance(unnested, exp.Table):
 69                    joins = unnested.args.pop("joins", None)
 70                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
 71                    derived_table.this.set("joins", joins)
 72
 73            if not derived_table.args.get("alias"):
 74                alias_ = next_alias_name()
 75                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
 76                scope.rename_source(None, alias_)
 77
 78            pivots = derived_table.args.get("pivots")
 79            if pivots and not pivots[0].alias:
 80                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))
 81
 82        table_aliases = {}
 83
 84        for name, source in scope.sources.items():
 85            if isinstance(source, exp.Table):
 86                pivots = pivots = source.args.get("pivots")
 87                if not source.alias:
 88                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
 89                    if pivots and pivots[0].alias == name:
 90                        name = source.name
 91
 92                    # Mutates the source by attaching an alias to it
 93                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)
 94
 95                table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
 96                    source.alias
 97                )
 98
 99                _qualify(source)
100
101                if pivots and not pivots[0].alias:
102                    pivots[0].set(
103                        "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
104                    )
105
106                if schema and isinstance(source.this, exp.ReadCSV):
107                    with csv_reader(source.this) as reader:
108                        header = next(reader)
109                        columns = next(reader)
110                        schema.add_table(
111                            source,
112                            {k: type(v).__name__ for k, v in zip(header, columns)},
113                            match_depth=False,
114                        )
115            elif isinstance(source, Scope) and source.is_udtf:
116                udtf = source.expression
117                table_alias = udtf.args.get("alias") or exp.TableAlias(
118                    this=exp.to_identifier(next_alias_name())
119                )
120                udtf.set("alias", table_alias)
121
122                if not table_alias.name:
123                    table_alias.set("this", exp.to_identifier(next_alias_name()))
124                if isinstance(udtf, exp.Values) and not table_alias.columns:
125                    for i, e in enumerate(udtf.expressions[0].expressions):
126                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
127            else:
128                for node in scope.walk():
129                    if (
130                        isinstance(node, exp.Table)
131                        and not node.alias
132                        and isinstance(node.parent, (exp.From, exp.Join))
133                    ):
134                        # Mutates the table by attaching an alias to it
135                        alias(node, node.name, copy=False, table=True)
136
137        for column in scope.columns:
138            if column.db:
139                table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))
140
141                if table_alias:
142                    for p in exp.COLUMN_PARTS[1:]:
143                        column.set(p, None)
144                    column.set("table", table_alias)
145
146    return expression

Rewrite sqlglot AST to have fully qualified tables. Join constructs such as (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
>>> qualify_tables(expression).sql()
'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
Arguments:
  • expression: Expression to qualify
  • db: Database name
  • catalog: Catalog name
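  • schema: A schema to populate. Tables read through ReadCSV are added to it, using the first CSV row as the column names.
  • dialect: The dialect used to parse db and catalog into identifiers.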
Returns:

The qualified expression.