diff options
Diffstat (limited to 'sqlglot/optimizer/eliminate_ctes.py')
-rw-r--r-- | sqlglot/optimizer/eliminate_ctes.py | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/sqlglot/optimizer/eliminate_ctes.py b/sqlglot/optimizer/eliminate_ctes.py new file mode 100644 index 0000000..7b862c6 --- /dev/null +++ b/sqlglot/optimizer/eliminate_ctes.py @@ -0,0 +1,42 @@ +from sqlglot.optimizer.scope import Scope, build_scope + + +def eliminate_ctes(expression): + """ + Remove unused CTEs from an expression. + + Example: + >>> import sqlglot + >>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z" + >>> expression = sqlglot.parse_one(sql) + >>> eliminate_ctes(expression).sql() + 'SELECT a FROM z' + + Args: + expression (sqlglot.Expression): expression to optimize + Returns: + sqlglot.Expression: optimized expression + """ + root = build_scope(expression) + + ref_count = root.ref_count() + + # Traverse the scope tree in reverse so we can remove chains of unused CTEs + for scope in reversed(list(root.traverse())): + if scope.is_cte: + count = ref_count[id(scope)] + if count <= 0: + cte_node = scope.expression.parent + with_node = cte_node.parent + cte_node.pop() + + # Pop the entire WITH clause if this is the last CTE + if len(with_node.expressions) <= 0: + with_node.pop() + + # Decrement the ref count for all sources this CTE selects from + for _, source in scope.selected_sources.values(): + if isinstance(source, Scope): + ref_count[id(source)] -= 1 + + return expression |