summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/eliminate_subqueries.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/eliminate_subqueries.py')
-rw-r--r--sqlglot/optimizer/eliminate_subqueries.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
new file mode 100644
index 0000000..4bfb733
--- /dev/null
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -0,0 +1,48 @@
+import itertools
+
+from sqlglot import alias, exp, select, table
+from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.simplify import simplify
+
+
+def eliminate_subqueries(expression):
+ """
+ Rewrite duplicate subqueries from sqlglot AST.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y")
+ >>> eliminate_subqueries(expression).sql()
+ 'WITH _e_0 AS (SELECT 1 AS x, 2 AS y) SELECT * FROM _e_0 UNION ALL SELECT * FROM _e_0'
+
+ Args:
+ expression (sqlglot.Expression): expression to qualify
+ schema (dict|sqlglot.optimizer.Schema): Database schema
+ Returns:
+ sqlglot.Expression: qualified expression
+ """
+ expression = simplify(expression)
+ queries = {}
+
+ for scope in traverse_scope(expression):
+ query = scope.expression
+ queries[query] = queries.get(query, []) + [query]
+
+ sequence = itertools.count()
+
+ for query, duplicates in queries.items():
+ if len(duplicates) == 1:
+ continue
+
+ alias_ = f"_e_{next(sequence)}"
+
+ for dup in duplicates:
+ parent = dup.parent
+ if isinstance(parent, exp.Subquery):
+ parent.replace(alias(table(alias_), parent.alias_or_name, table=True))
+ elif isinstance(parent, exp.Union):
+ dup.replace(select("*").from_(alias_))
+
+ expression.with_(alias_, as_=query, copy=False)
+
+ return expression