diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/transforms.py | 70 |
1 files changed, 63 insertions, 7 deletions
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index e0fd68f..445fda6 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -67,7 +67,7 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: order = expression.args.get("order") if order: - window.set("order", order.pop().copy()) + window.set("order", order.pop()) else: window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols])) @@ -75,9 +75,9 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: expression.select(window, copy=False) return ( - exp.select(*outer_selects) - .from_(expression.subquery("_t")) - .where(exp.column(row_number).eq(1)) + exp.select(*outer_selects, copy=False) + .from_(expression.subquery("_t", copy=False), copy=False) + .where(exp.column(row_number).eq(1), copy=False) ) return expression @@ -120,7 +120,9 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression: elif expr.name not in expression.named_selects: expression.select(expr.copy(), copy=False) - return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters) + return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where( + qualify_filters, copy=False + ) return expression @@ -189,7 +191,7 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp ) # we use list here because expression.selects is mutated inside the loop - for select in expression.selects.copy(): + for select in list(expression.selects): explode = select.find(exp.Explode) if explode: @@ -374,6 +376,60 @@ def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression: return expression +def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression: + """ + Converts a query with a FULL OUTER join to a union of identical queries that + use LEFT/RIGHT OUTER joins instead. This transformation currently only works + for queries that have a single FULL OUTER join. + """ + if isinstance(expression, exp.Select): + full_outer_joins = [ + (index, join) + for index, join in enumerate(expression.args.get("joins") or []) + if join.side == "FULL" and join.kind == "OUTER" + ] + + if len(full_outer_joins) == 1: + expression_copy = expression.copy() + index, full_outer_join = full_outer_joins[0] + full_outer_join.set("side", "left") + expression_copy.args["joins"][index].set("side", "right") + + return exp.union(expression, expression_copy, copy=False) + + return expression + + +def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression: + """ + Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be + defined at the top-level, so for example queries like: + + SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq + + are invalid in those dialects. This transformation can be used to ensure all CTEs are + moved to the top level so that the final SQL code is valid from a syntax standpoint. + + TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly). + """ + top_level_with = expression.args.get("with") + for node in expression.find_all(exp.With): + if node.parent is expression: + continue + + inner_with = node.pop() + if not top_level_with: + top_level_with = inner_with + expression.set("with", top_level_with) + else: + if inner_with.recursive: + top_level_with.set("recursive", True) + + top_level_with.expressions.extend(inner_with.expressions) + + return expression + + def preprocess( transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], ) -> t.Callable[[Generator, exp.Expression], str]: @@ -392,7 +448,7 @@ def preprocess( def _to_sql(self, expression: exp.Expression) -> str: expression_type = type(expression) - expression = transforms[0](expression.copy()) + expression = transforms[0](expression) for t in transforms[1:]: expression = t(expression) |