Edit on GitHub

sqlglot.optimizer.qualify_tables

 1import itertools
 2
 3from sqlglot import alias, exp
 4from sqlglot.helper import csv_reader
 5from sqlglot.optimizer.scope import Scope, traverse_scope
 6
 7
 8def qualify_tables(expression, db=None, catalog=None, schema=None):
 9    """
10    Rewrite sqlglot AST to have fully qualified tables.
11
12    Example:
13        >>> import sqlglot
14        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
15        >>> qualify_tables(expression, db="db").sql()
16        'SELECT 1 FROM db.tbl AS tbl'
17
18    Args:
19        expression (sqlglot.Expression): expression to qualify
20        db (str): Database name
21        catalog (str): Catalog name
22        schema: A schema to populate
23    Returns:
24        sqlglot.Expression: qualified expression
25    """
26    sequence = itertools.count()
27
28    next_name = lambda: f"_q_{next(sequence)}"
29
30    for scope in traverse_scope(expression):
31        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
32            if not derived_table.args.get("alias"):
33                alias_ = f"_q_{next(sequence)}"
34                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
35                scope.rename_source(None, alias_)
36
37        for name, source in scope.sources.items():
38            if isinstance(source, exp.Table):
39                if isinstance(source.this, exp.Identifier):
40                    if not source.args.get("db"):
41                        source.set("db", exp.to_identifier(db))
42                    if not source.args.get("catalog"):
43                        source.set("catalog", exp.to_identifier(catalog))
44
45                if not source.alias:
46                    source = source.replace(
47                        alias(
48                            source.copy(),
49                            name if name else next_name(),
50                            table=True,
51                        )
52                    )
53
54                if schema and isinstance(source.this, exp.ReadCSV):
55                    with csv_reader(source.this) as reader:
56                        header = next(reader)
57                        columns = next(reader)
58                        schema.add_table(
59                            source, {k: type(v).__name__ for k, v in zip(header, columns)}
60                        )
61            elif isinstance(source, Scope) and source.is_udtf:
62                udtf = source.expression
63                table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name())
64                udtf.set("alias", table_alias)
65
66                if not table_alias.name:
67                    table_alias.set("this", next_name())
68                if isinstance(udtf, exp.Values) and not table_alias.columns:
69                    for i, e in enumerate(udtf.expressions[0].expressions):
70                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
71
72    return expression
def qualify_tables(expression, db=None, catalog=None, schema=None):
 9def qualify_tables(expression, db=None, catalog=None, schema=None):
10    """
11    Rewrite sqlglot AST to have fully qualified tables.
12
13    Example:
14        >>> import sqlglot
15        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
16        >>> qualify_tables(expression, db="db").sql()
17        'SELECT 1 FROM db.tbl AS tbl'
18
19    Args:
20        expression (sqlglot.Expression): expression to qualify
21        db (str): Database name
22        catalog (str): Catalog name
23        schema: A schema to populate
24    Returns:
25        sqlglot.Expression: qualified expression
26    """
27    sequence = itertools.count()
28
29    next_name = lambda: f"_q_{next(sequence)}"
30
31    for scope in traverse_scope(expression):
32        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
33            if not derived_table.args.get("alias"):
34                alias_ = f"_q_{next(sequence)}"
35                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
36                scope.rename_source(None, alias_)
37
38        for name, source in scope.sources.items():
39            if isinstance(source, exp.Table):
40                if isinstance(source.this, exp.Identifier):
41                    if not source.args.get("db"):
42                        source.set("db", exp.to_identifier(db))
43                    if not source.args.get("catalog"):
44                        source.set("catalog", exp.to_identifier(catalog))
45
46                if not source.alias:
47                    source = source.replace(
48                        alias(
49                            source.copy(),
50                            name if name else next_name(),
51                            table=True,
52                        )
53                    )
54
55                if schema and isinstance(source.this, exp.ReadCSV):
56                    with csv_reader(source.this) as reader:
57                        header = next(reader)
58                        columns = next(reader)
59                        schema.add_table(
60                            source, {k: type(v).__name__ for k, v in zip(header, columns)}
61                        )
62            elif isinstance(source, Scope) and source.is_udtf:
63                udtf = source.expression
64                table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name())
65                udtf.set("alias", table_alias)
66
67                if not table_alias.name:
68                    table_alias.set("this", next_name())
69                if isinstance(udtf, exp.Values) and not table_alias.columns:
70                    for i, e in enumerate(udtf.expressions[0].expressions):
71                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
72
73    return expression

Rewrite sqlglot AST to have fully qualified tables.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
Arguments:
  • expression (sqlglot.Expression): expression to qualify
  • db (str): Database name
  • catalog (str): Catalog name
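  • schema: A schema to populate; when provided, each table read from a CSV file is registered on it with its column names and value types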
Returns:

sqlglot.Expression: qualified expression