summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe/sql/normalize.py
blob: 15139468353a3e5ac0e221e5debf7df2322b1f20 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.helper import ensure_list

# Element type accepted by ``normalize`` / ``_ensure_expressions``: a ``str`` (converted through
# ``Column.ensure_col``), a raw sqlglot ``exp.Expression``, or a dataframe ``Column``. Any other type
# is rejected at runtime with ``ValueError`` by ``_ensure_expressions``.
NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.session import SparkSession


def normalize(
    spark: SparkSession,
    expression_context: exp.Select,
    expr: t.Union[NORMALIZE_INPUT, t.List[NORMALIZE_INPUT]],
) -> None:
    """Rewrite identifiers inside ``expr`` so they reference CTE names of ``expression_context``.

    Each input is converted to a sqlglot expression via ``_ensure_expressions``. Then every
    ``exp.Identifier`` inside it gets two rewrites, in order: first the alias-name rewrite
    (``replace_alias_name_with_cte_name``), then the branch/sequence-id rewrite
    (``replace_branch_and_sequence_ids_with_cte_name``). Identifiers are mutated in place;
    nothing is returned.

    Args:
        spark: session holding the alias / branch-id / sequence-id bookkeeping used by the rewrites.
        expression_context: the select whose CTEs are the rewrite targets.
        expr: a single input or a list of inputs (``str``, ``exp.Expression`` or ``Column``).

    Raises:
        ValueError: if an input is not one of the supported types.
    """
    # NOTE(review): string inputs are turned into new expressions by ``Column.ensure_col``, so
    # rewrites on them are presumably not visible to the caller; only Column/Expression inputs
    # are observably changed — confirm this is intended.
    # ``_ensure_expressions`` already wraps a single value in a list, so no ``ensure_list`` here.
    for expression in _ensure_expressions(expr):
        for identifier in expression.find_all(exp.Identifier):
            replace_alias_name_with_cte_name(spark, expression_context, identifier)
            replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)


def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier):
    """Point an identifier that names a known alias at the most recent CTE registered for it.

    ``spark.name_to_sequence_id_mapping`` maps an alias to a collection of sequence ids. The CTEs of
    ``expression_context`` are scanned from last to first. The first CTE whose ``sequence_id`` is in
    that collection supplies the new name for ``id``, which is mutated in place. Identifiers whose
    name is not a key of the mapping, or that match no CTE, are left untouched.
    """
    alias = id.alias_or_name
    if alias not in spark.name_to_sequence_id_mapping:
        return

    candidate_sequence_ids = spark.name_to_sequence_id_mapping[alias]
    newest_match = next(
        (cte for cte in reversed(expression_context.ctes) if cte.args["sequence_id"] in candidate_sequence_ids),
        None,
    )
    if newest_match is not None:
        _set_alias_name(id, newest_match.alias_or_name)


def replace_branch_and_sequence_ids_with_cte_name(
    spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
):
    """Rewrite an identifier that holds a branch id or sequence id to the matching CTE's alias.

    Only identifiers whose name is in ``spark.known_ids`` are considered; ``id`` is mutated in place
    and nothing is returned.

    Args:
        spark: session providing ``known_ids`` and ``known_branch_ids``.
        expression_context: the select whose CTEs (and joins) are searched.
        id: the identifier to rewrite.
    """
    # ``id`` is not mutated until just before each ``return``, so the name can be read once.
    name = id.alias_or_name
    if name not in spark.known_ids:
        return

    # Check if we have a join and if both the tables in that join share a common branch id
    # If so we need to have this reference the left table by default unless the id is a sequence
    # id then it keeps that reference. This handles the weird edge case in spark that shouldn't
    # be common in practice
    if expression_context.args.get("joins") and name in spark.known_branch_ids:
        join_table_aliases = {x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)}
        ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases]
        # Fewer than two joined tables may be CTEs (e.g. a join against a non-CTE table). Check the
        # length before indexing so that case falls through to the general lookup below instead of
        # raising IndexError.
        if len(ctes_in_join) >= 2 and ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
            # Invariant kept from the original: the shared-branch rule only handles a two-table join.
            assert len(ctes_in_join) == 2
            _set_alias_name(id, ctes_in_join[0].alias_or_name)
            return

    # General case: the most recently defined CTE whose branch or sequence id matches wins.
    for cte in reversed(expression_context.ctes):
        if name in (cte.args["branch_id"], cte.args["sequence_id"]):
            _set_alias_name(id, cte.alias_or_name)
            return


def _set_alias_name(id: exp.Identifier, name: str):
    id.set("this", name)


def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
    """Convert one input, or a list of inputs, into a list of sqlglot expressions.

    Conversion rules:

    - ``str`` is passed through ``Column.ensure_col`` and its ``.expression`` is used.
    - ``Column`` contributes its ``.expression``.
    - ``exp.Expression`` is used as-is.

    Raises:
        ValueError: for any other type. Values are processed in order, so the first invalid
            one raises.
    """

    def _as_expression(item: NORMALIZE_INPUT) -> exp.Expression:
        if isinstance(item, str):
            return Column.ensure_col(item).expression
        if isinstance(item, Column):
            return item.expression
        if isinstance(item, exp.Expression):
            return item
        raise ValueError(f"Got an invalid type to normalize: {type(item)}")

    return [_as_expression(item) for item in ensure_list(values)]