summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/eliminate_ctes.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/eliminate_ctes.py')
-rw-r--r--sqlglot/optimizer/eliminate_ctes.py42
1 files changed, 42 insertions, 0 deletions
diff --git a/sqlglot/optimizer/eliminate_ctes.py b/sqlglot/optimizer/eliminate_ctes.py
new file mode 100644
index 0000000..7b862c6
--- /dev/null
+++ b/sqlglot/optimizer/eliminate_ctes.py
@@ -0,0 +1,42 @@
+from sqlglot.optimizer.scope import Scope, build_scope
+
+
+def eliminate_ctes(expression):
+ """
+ Remove unused CTEs from an expression.
+
+ Example:
+ >>> import sqlglot
+ >>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z"
+ >>> expression = sqlglot.parse_one(sql)
+ >>> eliminate_ctes(expression).sql()
+ 'SELECT a FROM z'
+
+ Args:
+ expression (sqlglot.Expression): expression to optimize
+ Returns:
+ sqlglot.Expression: optimized expression
+ """
+ root = build_scope(expression)
+
+ ref_count = root.ref_count()
+
+ # Traverse the scope tree in reverse so we can remove chains of unused CTEs
+ for scope in reversed(list(root.traverse())):
+ if scope.is_cte:
+ count = ref_count[id(scope)]
+ if count <= 0:
+ cte_node = scope.expression.parent
+ with_node = cte_node.parent
+ cte_node.pop()
+
+ # Pop the entire WITH clause if this is the last CTE
+ if len(with_node.expressions) <= 0:
+ with_node.pop()
+
+ # Decrement the ref count for all sources this CTE selects from
+ for _, source in scope.selected_sources.values():
+ if isinstance(source, Scope):
+ ref_count[id(source)] -= 1
+
+ return expression