Edit on GitHub

sqlglot.optimizer.qualify_tables

  1import itertools
  2import typing as t
  3
  4from sqlglot import alias, exp
  5from sqlglot._typing import E
  6from sqlglot.helper import csv_reader, name_sequence
  7from sqlglot.optimizer.scope import Scope, traverse_scope
  8from sqlglot.schema import Schema
  9
 10
 11def qualify_tables(
 12    expression: E,
 13    db: t.Optional[str] = None,
 14    catalog: t.Optional[str] = None,
 15    schema: t.Optional[Schema] = None,
 16) -> E:
 17    """
 18    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
 19    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
 20
 21    Examples:
 22        >>> import sqlglot
 23        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
 24        >>> qualify_tables(expression, db="db").sql()
 25        'SELECT 1 FROM db.tbl AS tbl'
 26        >>>
 27        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
 28        >>> qualify_tables(expression).sql()
 29        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
 30
 31    Args:
 32        expression: Expression to qualify
 33        db: Database name
 34        catalog: Catalog name
 35        schema: A schema to populate
 36
 37    Returns:
 38        The qualified expression.
 39    """
 40    next_alias_name = name_sequence("_q_")
 41
 42    for scope in traverse_scope(expression):
 43        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
 44            if isinstance(derived_table, exp.Subquery):
 45                unnested = derived_table.unnest()
 46                if isinstance(unnested, exp.Table):
 47                    joins = unnested.args.pop("joins", None)
 48                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
 49                    derived_table.this.set("joins", joins)
 50
 51            if not derived_table.args.get("alias"):
 52                alias_ = next_alias_name()
 53                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
 54                scope.rename_source(None, alias_)
 55
 56            pivots = derived_table.args.get("pivots")
 57            if pivots and not pivots[0].alias:
 58                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))
 59
 60        for name, source in scope.sources.items():
 61            if isinstance(source, exp.Table):
 62                if isinstance(source.this, exp.Identifier):
 63                    if not source.args.get("db"):
 64                        source.set("db", exp.to_identifier(db))
 65                    if not source.args.get("catalog"):
 66                        source.set("catalog", exp.to_identifier(catalog))
 67
 68                if not source.alias:
 69                    # Mutates the source by attaching an alias to it
 70                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)
 71
 72                pivots = source.args.get("pivots")
 73                if pivots and not pivots[0].alias:
 74                    pivots[0].set(
 75                        "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
 76                    )
 77
 78                if schema and isinstance(source.this, exp.ReadCSV):
 79                    with csv_reader(source.this) as reader:
 80                        header = next(reader)
 81                        columns = next(reader)
 82                        schema.add_table(
 83                            source,
 84                            {k: type(v).__name__ for k, v in zip(header, columns)},
 85                            match_depth=False,
 86                        )
 87            elif isinstance(source, Scope) and source.is_udtf:
 88                udtf = source.expression
 89                table_alias = udtf.args.get("alias") or exp.TableAlias(
 90                    this=exp.to_identifier(next_alias_name())
 91                )
 92                udtf.set("alias", table_alias)
 93
 94                if not table_alias.name:
 95                    table_alias.set("this", exp.to_identifier(next_alias_name()))
 96                if isinstance(udtf, exp.Values) and not table_alias.columns:
 97                    for i, e in enumerate(udtf.expressions[0].expressions):
 98                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
 99
100    return expression
def qualify_tables( expression: ~E, db: Optional[str] = None, catalog: Optional[str] = None, schema: Optional[sqlglot.schema.Schema] = None) -> ~E:
 12def qualify_tables(
 13    expression: E,
 14    db: t.Optional[str] = None,
 15    catalog: t.Optional[str] = None,
 16    schema: t.Optional[Schema] = None,
 17) -> E:
 18    """
 19    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
 20    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
 21
 22    Examples:
 23        >>> import sqlglot
 24        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
 25        >>> qualify_tables(expression, db="db").sql()
 26        'SELECT 1 FROM db.tbl AS tbl'
 27        >>>
 28        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
 29        >>> qualify_tables(expression).sql()
 30        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
 31
 32    Args:
 33        expression: Expression to qualify
 34        db: Database name
 35        catalog: Catalog name
 36        schema: A schema to populate
 37
 38    Returns:
 39        The qualified expression.
 40    """
 41    next_alias_name = name_sequence("_q_")
 42
 43    for scope in traverse_scope(expression):
 44        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
 45            if isinstance(derived_table, exp.Subquery):
 46                unnested = derived_table.unnest()
 47                if isinstance(unnested, exp.Table):
 48                    joins = unnested.args.pop("joins", None)
 49                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
 50                    derived_table.this.set("joins", joins)
 51
 52            if not derived_table.args.get("alias"):
 53                alias_ = next_alias_name()
 54                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
 55                scope.rename_source(None, alias_)
 56
 57            pivots = derived_table.args.get("pivots")
 58            if pivots and not pivots[0].alias:
 59                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))
 60
 61        for name, source in scope.sources.items():
 62            if isinstance(source, exp.Table):
 63                if isinstance(source.this, exp.Identifier):
 64                    if not source.args.get("db"):
 65                        source.set("db", exp.to_identifier(db))
 66                    if not source.args.get("catalog"):
 67                        source.set("catalog", exp.to_identifier(catalog))
 68
 69                if not source.alias:
 70                    # Mutates the source by attaching an alias to it
 71                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)
 72
 73                pivots = source.args.get("pivots")
 74                if pivots and not pivots[0].alias:
 75                    pivots[0].set(
 76                        "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
 77                    )
 78
 79                if schema and isinstance(source.this, exp.ReadCSV):
 80                    with csv_reader(source.this) as reader:
 81                        header = next(reader)
 82                        columns = next(reader)
 83                        schema.add_table(
 84                            source,
 85                            {k: type(v).__name__ for k, v in zip(header, columns)},
 86                            match_depth=False,
 87                        )
 88            elif isinstance(source, Scope) and source.is_udtf:
 89                udtf = source.expression
 90                table_alias = udtf.args.get("alias") or exp.TableAlias(
 91                    this=exp.to_identifier(next_alias_name())
 92                )
 93                udtf.set("alias", table_alias)
 94
 95                if not table_alias.name:
 96                    table_alias.set("this", exp.to_identifier(next_alias_name()))
 97                if isinstance(udtf, exp.Values) and not table_alias.columns:
 98                    for i, e in enumerate(udtf.expressions[0].expressions):
 99                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
100
101    return expression

Rewrite the sqlglot AST so that all tables are fully qualified. Join constructs such as (t1 JOIN t2) AS t are expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t. Unaliased tables, derived tables, pivots and UDTFs are also given generated aliases.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
>>> qualify_tables(expression).sql()
'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
Arguments:
  • expression: Expression to qualify
  • db: Database name
  • catalog: Catalog name
  • schema: A schema to populate
Returns:
The qualified expression (the same object that was passed in, modified in place).