diff options
Diffstat (limited to 'sqlglot/optimizer/eliminate_ctes.py')
-rw-r--r-- | sqlglot/optimizer/eliminate_ctes.py | 39 |
1 files changed, 20 insertions, 19 deletions
diff --git a/sqlglot/optimizer/eliminate_ctes.py b/sqlglot/optimizer/eliminate_ctes.py index 7b862c6..6f1865c 100644 --- a/sqlglot/optimizer/eliminate_ctes.py +++ b/sqlglot/optimizer/eliminate_ctes.py @@ -19,24 +19,25 @@ def eliminate_ctes(expression): """ root = build_scope(expression) - ref_count = root.ref_count() - - # Traverse the scope tree in reverse so we can remove chains of unused CTEs - for scope in reversed(list(root.traverse())): - if scope.is_cte: - count = ref_count[id(scope)] - if count <= 0: - cte_node = scope.expression.parent - with_node = cte_node.parent - cte_node.pop() - - # Pop the entire WITH clause if this is the last CTE - if len(with_node.expressions) <= 0: - with_node.pop() - - # Decrement the ref count for all sources this CTE selects from - for _, source in scope.selected_sources.values(): - if isinstance(source, Scope): - ref_count[id(source)] -= 1 + if root: + ref_count = root.ref_count() + + # Traverse the scope tree in reverse so we can remove chains of unused CTEs + for scope in reversed(list(root.traverse())): + if scope.is_cte: + count = ref_count[id(scope)] + if count <= 0: + cte_node = scope.expression.parent + with_node = cte_node.parent + cte_node.pop() + + # Pop the entire WITH clause if this is the last CTE + if len(with_node.expressions) <= 0: + with_node.pop() + + # Decrement the ref count for all sources this CTE selects from + for _, source in scope.selected_sources.values(): + if isinstance(source, Scope): + ref_count[id(source)] -= 1 return expression |