summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/eliminate_ctes.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/eliminate_ctes.py')
-rw-r--r--sqlglot/optimizer/eliminate_ctes.py39
1 files changed, 20 insertions, 19 deletions
diff --git a/sqlglot/optimizer/eliminate_ctes.py b/sqlglot/optimizer/eliminate_ctes.py
index 7b862c6..6f1865c 100644
--- a/sqlglot/optimizer/eliminate_ctes.py
+++ b/sqlglot/optimizer/eliminate_ctes.py
@@ -19,24 +19,25 @@ def eliminate_ctes(expression):
"""
root = build_scope(expression)
- ref_count = root.ref_count()
-
- # Traverse the scope tree in reverse so we can remove chains of unused CTEs
- for scope in reversed(list(root.traverse())):
- if scope.is_cte:
- count = ref_count[id(scope)]
- if count <= 0:
- cte_node = scope.expression.parent
- with_node = cte_node.parent
- cte_node.pop()
-
- # Pop the entire WITH clause if this is the last CTE
- if len(with_node.expressions) <= 0:
- with_node.pop()
-
- # Decrement the ref count for all sources this CTE selects from
- for _, source in scope.selected_sources.values():
- if isinstance(source, Scope):
- ref_count[id(source)] -= 1
+ if root:
+ ref_count = root.ref_count()
+
+ # Traverse the scope tree in reverse so we can remove chains of unused CTEs
+ for scope in reversed(list(root.traverse())):
+ if scope.is_cte:
+ count = ref_count[id(scope)]
+ if count <= 0:
+ cte_node = scope.expression.parent
+ with_node = cte_node.parent
+ cte_node.pop()
+
+ # Pop the entire WITH clause if this is the last CTE
+ if len(with_node.expressions) <= 0:
+ with_node.pop()
+
+ # Decrement the ref count for all sources this CTE selects from
+ for _, source in scope.selected_sources.values():
+ if isinstance(source, Scope):
+ ref_count[id(source)] -= 1
return expression