summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/transforms.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r-- sqlglot/transforms.py 70
1 files changed, 63 insertions, 7 deletions
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index e0fd68f..445fda6 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -67,7 +67,7 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
order = expression.args.get("order")
if order:
- window.set("order", order.pop().copy())
+ window.set("order", order.pop())
else:
window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
@@ -75,9 +75,9 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
expression.select(window, copy=False)
return (
- exp.select(*outer_selects)
- .from_(expression.subquery("_t"))
- .where(exp.column(row_number).eq(1))
+ exp.select(*outer_selects, copy=False)
+ .from_(expression.subquery("_t", copy=False), copy=False)
+ .where(exp.column(row_number).eq(1), copy=False)
)
return expression
@@ -120,7 +120,9 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
elif expr.name not in expression.named_selects:
expression.select(expr.copy(), copy=False)
- return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
+ return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
+ qualify_filters, copy=False
+ )
return expression
@@ -189,7 +191,7 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
)
# we use list here because expression.selects is mutated inside the loop
- for select in expression.selects.copy():
+ for select in list(expression.selects):
explode = select.find(exp.Explode)
if explode:
@@ -374,6 +376,60 @@ def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
return expression
+def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
+    """
+    Converts a query with a FULL OUTER join to a union of identical queries that
+    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
+    for queries that have a single FULL OUTER join.
+
+    The OUTER keyword is optional after FULL, so a join whose side is FULL and whose
+    kind is either unset or OUTER is treated as a full outer join.
+
+    Args:
+        expression: The expression to transform. It is mutated in place when the
+            rewrite applies (the matched join's side is changed to "left").
+
+    Returns:
+        The union of the LEFT-join query and a RIGHT-join copy of it when exactly one
+        full outer join is found; otherwise the input expression, unchanged.
+
+    NOTE(review): the two branches are combined through `exp.union` with its default
+    arguments. If that default is UNION DISTINCT, duplicate rows of the original result
+    would be collapsed - confirm against the `exp.union` documentation.
+    """
+    if not isinstance(expression, exp.Select):
+        return expression
+
+    full_outer_joins = [
+        (index, join)
+        for index, join in enumerate(expression.args.get("joins") or [])
+        # `FULL JOIN` and `FULL OUTER JOIN` are the same join, so accept an unset kind too
+        if join.side == "FULL" and (not join.kind or join.kind == "OUTER")
+    ]
+
+    if len(full_outer_joins) != 1:
+        return expression
+
+    # Copy before touching the join so that the copy still carries the original FULL side
+    expression_copy = expression.copy()
+    index, full_outer_join = full_outer_joins[0]
+    full_outer_join.set("side", "left")
+    expression_copy.args["joins"][index].set("side", "right")
+
+    return exp.union(expression, expression_copy, copy=False)
+
+
+def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
+    """
+    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
+    defined at the top-level, so for example queries like:
+
+        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
+
+    are invalid in those dialects. This transformation can be used to ensure all CTEs are
+    moved to the top level so that the final SQL code is valid from a syntax standpoint.
+
+    The first nested WITH that is found becomes the top-level WITH when the expression has
+    none; the CTEs of every other nested WITH are appended after the existing top-level
+    ones, and the top-level WITH is flagged recursive if any merged WITH was recursive.
+
+    Args:
+        expression: The expression to transform; it is modified in place.
+
+    Returns:
+        The same expression, with nested WITH clauses detached and merged at the top level.
+
+    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
+    """
+    top_level_with = expression.args.get("with")
+
+    # Collect the matches up front: the loop detaches nodes from the very tree that
+    # `find_all` is walking, so the traversal must not be consumed lazily.
+    for node in list(expression.find_all(exp.With)):
+        if node.parent is expression:
+            continue
+
+        inner_with = node.pop()
+        if not top_level_with:
+            top_level_with = inner_with
+            expression.set("with", top_level_with)
+        else:
+            if inner_with.recursive:
+                top_level_with.set("recursive", True)
+
+            # Assign the merged list through `set` rather than extending it in place:
+            # the in-place `extend` bypassed `set`, leaving the moved CTEs attached to
+            # the detached inner WITH instead of the top-level one.
+            top_level_with.set(
+                "expressions", top_level_with.expressions + inner_with.expressions
+            )
+
+    return expression
+
+
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
@@ -392,7 +448,7 @@ def preprocess(
def _to_sql(self, expression: exp.Expression) -> str:
expression_type = type(expression)
- expression = transforms[0](expression.copy())
+ expression = transforms[0](expression)
for t in transforms[1:]:
expression = t(expression)