From 97d3673ec2d668050912aa6aea1816885ca6c5ab Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 15 Oct 2022 15:52:53 +0200 Subject: Adding upstream version 7.1.3. Signed-off-by: Daniel Baumann --- sqlglot/optimizer/eliminate_ctes.py | 42 +++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 sqlglot/optimizer/eliminate_ctes.py (limited to 'sqlglot/optimizer/eliminate_ctes.py') diff --git a/sqlglot/optimizer/eliminate_ctes.py b/sqlglot/optimizer/eliminate_ctes.py new file mode 100644 index 0000000..7b862c6 --- /dev/null +++ b/sqlglot/optimizer/eliminate_ctes.py @@ -0,0 +1,42 @@ +from sqlglot.optimizer.scope import Scope, build_scope + + +def eliminate_ctes(expression): + """ + Remove unused CTEs from an expression. + + Example: + >>> import sqlglot + >>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z" + >>> expression = sqlglot.parse_one(sql) + >>> eliminate_ctes(expression).sql() + 'SELECT a FROM z' + + Args: + expression (sqlglot.Expression): expression to optimize + Returns: + sqlglot.Expression: optimized expression + """ + root = build_scope(expression) + + ref_count = root.ref_count() + + # Traverse the scope tree in reverse so we can remove chains of unused CTEs + for scope in reversed(list(root.traverse())): + if scope.is_cte: + count = ref_count[id(scope)] + if count <= 0: + cte_node = scope.expression.parent + with_node = cte_node.parent + cte_node.pop() + + # Pop the entire WITH clause if this is the last CTE + if len(with_node.expressions) <= 0: + with_node.pop() + + # Decrement the ref count for all sources this CTE selects from + for _, source in scope.selected_sources.values(): + if isinstance(source, Scope): + ref_count[id(source)] -= 1 + + return expression -- cgit v1.2.3