summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe/sql/normalize.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dataframe/sql/normalize.py')
-rw-r--r--sqlglot/dataframe/sql/normalize.py72
1 files changed, 72 insertions, 0 deletions
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
new file mode 100644
index 0000000..1513946
--- /dev/null
+++ b/sqlglot/dataframe/sql/normalize.py
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
+from sqlglot.helper import ensure_list
+
+NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.session import SparkSession
+
+
+def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]):
+ expr = ensure_list(expr)
+ expressions = _ensure_expressions(expr)
+ for expression in expressions:
+ identifiers = expression.find_all(exp.Identifier)
+ for identifier in identifiers:
+ replace_alias_name_with_cte_name(spark, expression_context, identifier)
+ replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
+
+
+def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier):
+ if id.alias_or_name in spark.name_to_sequence_id_mapping:
+ for cte in reversed(expression_context.ctes):
+ if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]:
+ _set_alias_name(id, cte.alias_or_name)
+ break
+
+
+def replace_branch_and_sequence_ids_with_cte_name(
+ spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
+):
+ if id.alias_or_name in spark.known_ids:
+ # Check if we have a join and if both the tables in that join share a common branch id
+ # If so we need to have this reference the left table by default unless the id is a sequence
+ # id then it keeps that reference. This handles the weird edge case in spark that shouldn't
+ # be common in practice
+ if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
+ join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)]
+ ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases]
+ if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
+ assert len(ctes_in_join) == 2
+ _set_alias_name(id, ctes_in_join[0].alias_or_name)
+ return
+
+ for cte in reversed(expression_context.ctes):
+ if id.alias_or_name in (cte.args["branch_id"], cte.args["sequence_id"]):
+ _set_alias_name(id, cte.alias_or_name)
+ return
+
+
+def _set_alias_name(id: exp.Identifier, name: str):
+ id.set("this", name)
+
+
+def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
+ values = ensure_list(values)
+ results = []
+ for value in values:
+ if isinstance(value, str):
+ results.append(Column.ensure_col(value).expression)
+ elif isinstance(value, Column):
+ results.append(value.expression)
+ elif isinstance(value, exp.Expression):
+ results.append(value)
+ else:
+ raise ValueError(f"Got an invalid type to normalize: {type(value)}")
+ return results