summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/dataframe/sql/normalize.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dataframe/sql/normalize.py')
-rw-r--r-- sqlglot/dataframe/sql/normalize.py 13
1 files changed, 9 insertions, 4 deletions
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
index 1513946..75feba7 100644
--- a/sqlglot/dataframe/sql/normalize.py
+++ b/sqlglot/dataframe/sql/normalize.py
@@ -23,7 +23,9 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
-def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier):
+def replace_alias_name_with_cte_name(
+ spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
+):
if id.alias_or_name in spark.name_to_sequence_id_mapping:
for cte in reversed(expression_context.ctes):
if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]:
@@ -40,8 +42,12 @@ def replace_branch_and_sequence_ids_with_cte_name(
# id then it keeps that reference. This handles the weird edge case in spark that shouldn't
# be common in practice
if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
- join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)]
- ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases]
+ join_table_aliases = [
+ x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)
+ ]
+ ctes_in_join = [
+ cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases
+ ]
if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
assert len(ctes_in_join) == 2
_set_alias_name(id, ctes_in_join[0].alias_or_name)
@@ -58,7 +64,6 @@ def _set_alias_name(id: exp.Identifier, name: str):
def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
- values = ensure_list(values)
results = []
for value in values:
if isinstance(value, str):