diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/dataframe/sql/normalize.py | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py index 1513946..75feba7 100644 --- a/sqlglot/dataframe/sql/normalize.py +++ b/sqlglot/dataframe/sql/normalize.py @@ -23,7 +23,9 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[ replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier) -def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier): +def replace_alias_name_with_cte_name( + spark: SparkSession, expression_context: exp.Select, id: exp.Identifier +): if id.alias_or_name in spark.name_to_sequence_id_mapping: for cte in reversed(expression_context.ctes): if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]: @@ -40,8 +42,12 @@ def replace_branch_and_sequence_ids_with_cte_name( # id then it keeps that reference. This handles the weird edge case in spark that shouldn't # be common in practice if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids: - join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)] - ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases] + join_table_aliases = [ + x.alias_or_name for x in get_tables_from_expression_with_join(expression_context) + ] + ctes_in_join = [ + cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases + ] if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]: assert len(ctes_in_join) == 2 _set_alias_name(id, ctes_in_join[0].alias_or_name) @@ -58,7 +64,6 @@ def _set_alias_name(id: exp.Identifier, name: str): def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]: - values = ensure_list(values) results = [] for value in values: if isinstance(value, str): |