summaryrefslogtreecommitdiffstats
path: root/sqlglot/helper.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-06-16 09:41:15 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-06-16 09:41:15 +0000
commit358a09296d7198a4cc142f1976de8f3eb3318e58 (patch)
tree762db96c44014dc4db5e9fc7f6709c138589155e /sqlglot/helper.py
parentAdding upstream version 15.2.0. (diff)
downloadsqlglot-upstream/16.2.1.tar.xz
sqlglot-upstream/16.2.1.zip
Adding upstream version 16.2.1.upstream/16.2.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/helper.py')
-rw-r--r--sqlglot/helper.py28
1 files changed, 15 insertions, 13 deletions
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 4215fee..2f48ab5 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -208,7 +208,7 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) ->
return expression
-def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
+def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
"""
Sorts a given directed acyclic graph in topological order.
@@ -220,22 +220,24 @@ def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
"""
result = []
- def visit(node: T, visited: t.Set[T]) -> None:
- if node in result:
- return
- if node in visited:
- raise ValueError("Cycle error")
+ for node, deps in tuple(dag.items()):
+ for dep in deps:
+ if not dep in dag:
+ dag[dep] = set()
+
+ while dag:
+ current = {node for node, deps in dag.items() if not deps}
- visited.add(node)
+ if not current:
+ raise ValueError("Cycle error")
- for dep in dag.get(node, []):
- visit(dep, visited)
+ for node in current:
+ dag.pop(node)
- visited.remove(node)
- result.append(node)
+ for deps in dag.values():
+ deps -= current
- for node in dag:
- visit(node, set())
+ result.extend(sorted(current)) # type: ignore
return result