summaryrefslogtreecommitdiffstats
path: root/sqlglot/helper.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/helper.py')
-rw-r--r--sqlglot/helper.py28
1 files changed, 15 insertions, 13 deletions
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 4215fee..2f48ab5 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -208,7 +208,7 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) ->
return expression
-def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
+def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
"""
Sorts a given directed acyclic graph in topological order.
@@ -220,22 +220,24 @@ def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
"""
result = []
- def visit(node: T, visited: t.Set[T]) -> None:
- if node in result:
- return
- if node in visited:
- raise ValueError("Cycle error")
+ for node, deps in tuple(dag.items()):
+ for dep in deps:
+ if not dep in dag:
+ dag[dep] = set()
+
+ while dag:
+ current = {node for node, deps in dag.items() if not deps}
- visited.add(node)
+ if not current:
+ raise ValueError("Cycle error")
- for dep in dag.get(node, []):
- visit(dep, visited)
+ for node in current:
+ dag.pop(node)
- visited.remove(node)
- result.append(node)
+ for deps in dag.values():
+ deps -= current
- for node in dag:
- visit(node, set())
+ result.extend(sorted(current)) # type: ignore
return result