diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/helper.py | 28 |
1 files changed, 15 insertions, 13 deletions
diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 4215fee..2f48ab5 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -208,7 +208,7 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> return expression -def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]: +def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]: """ Sorts a given directed acyclic graph in topological order. @@ -220,22 +220,24 @@ def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]: """ result = [] - def visit(node: T, visited: t.Set[T]) -> None: - if node in result: - return - if node in visited: - raise ValueError("Cycle error") + for node, deps in tuple(dag.items()): + for dep in deps: + if not dep in dag: + dag[dep] = set() + + while dag: + current = {node for node, deps in dag.items() if not deps} - visited.add(node) + if not current: + raise ValueError("Cycle error") - for dep in dag.get(node, []): - visit(dep, visited) + for node in current: + dag.pop(node) - visited.remove(node) - result.append(node) + for deps in dag.values(): + deps -= current - for node in dag: - visit(node, set()) + result.extend(sorted(current)) # type: ignore return result |