#!/usr/bin/env python # -*- coding: utf-8 -*- # Copyright (C) 2011 # Brett Alistair Kromkamp - brettkromkamp@gmail.com # Copyright (C) 2012-2017 # Xiaming Chen - chenxm35@gmail.com # and other contributors. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Tree structure in `treelib`. The :class:`Tree` object defines the tree-like structure based on :class:`Node` objects. A new tree can be created from scratch without any parameter or a shallow/deep copy of another tree. When deep=True, a deepcopy operation is performed on feeding tree parameter and more memory is required to create the tree. """ from __future__ import print_function from __future__ import unicode_literals try: from builtins import str as text except ImportError: from __builtin__ import str as text import codecs import json import uuid from copy import deepcopy from six import python_2_unicode_compatible, iteritems try: from StringIO import StringIO except ImportError: from io import StringIO from .exceptions import ( NodeIDAbsentError, DuplicatedNodeIdError, MultipleRootError, InvalidLevelNumber, LinkPastRootNodeError, LoopError, ) from .node import Node __author__ = "chenxm" @python_2_unicode_compatible class Tree(object): """Tree objects are made of Node(s) stored in _nodes dictionary.""" #: ROOT, DEPTH, WIDTH, ZIGZAG constants : (ROOT, DEPTH, WIDTH, ZIGZAG) = list(range(4)) node_class = Node def __contains__(self, identifier): return identifier in self.nodes.keys() def __init__(self, tree=None, deep=False, node_class=None, identifier=None): """Initiate a new tree or copy another tree with a shallow or deep copy. """ self._identifier = None self._set_identifier(identifier) if node_class: assert issubclass(node_class, Node) self.node_class = node_class #: dictionary, identifier: Node object self._nodes = {} #: Get or set the identifier of the root. This attribute can be accessed and modified #: with ``.`` and ``=`` operator respectively. self.root = None if tree is not None: self.root = tree.root for nid, node in iteritems(tree.nodes): new_node = deepcopy(node) if deep else node self._nodes[nid] = new_node if tree.identifier != self._identifier: new_node.clone_pointers(tree.identifier, self._identifier) def _clone(self, identifier=None, with_tree=False, deep=False): """Clone current instance, with or without tree. Method intended to be overloaded, to avoid rewriting whole "subtree" and "remove_subtree" methods when inheriting from Tree. >>> class TreeWithComposition(Tree): >>> def __init__(self, tree_description, tree=None, deep=False, identifier=None): >>> self.tree_description = tree_description >>> super(TreeWithComposition, self).__init__(tree=tree, deep=deep, identifier=identifier) >>> >>> def _clone(self, identifier=None, with_tree=False, deep=False): >>> return TreeWithComposition( >>> identifier=identifier, >>> deep=deep, >>> tree=self if with_tree else None, >>> tree_description=self.tree_description >>> ) >>> my_custom_tree = TreeWithComposition(tree_description="smart tree") >>> subtree = my_custom_tree.subtree() >>> subtree.tree_description "smart tree" """ return self.__class__( identifier=identifier, tree=self if with_tree else None, deep=deep ) @property def identifier(self): return self._identifier def _set_identifier(self, nid): """Initialize self._set_identifier""" if nid is None: self._identifier = str(uuid.uuid1()) else: self._identifier = nid def __getitem__(self, key): """Return _nodes[key]""" try: return self._nodes[key] except KeyError: raise NodeIDAbsentError("Node '%s' is not in the tree" % key) def __len__(self): """Return len(_nodes)""" return len(self._nodes) def __str__(self): self._reader = "" def write(line): self._reader += line.decode("utf-8") + "\n" self.__print_backend(func=write) return self._reader def __print_backend( self, nid=None, level=ROOT, idhidden=True, filter=None, key=None, reverse=False, line_type="ascii-ex", data_property=None, sorting=True, func=print, ): """ Another implementation of printing tree using Stack Print tree structure in hierarchy style. For example: .. code-block:: bash Root |___ C01 | |___ C11 | |___ C111 | |___ C112 |___ C02 |___ C03 | |___ C31 A more elegant way to achieve this function using Stack structure, for constructing the Nodes Stack push and pop nodes with additional level info. UPDATE: the @key @reverse is present to sort node at each level. """ # Factory for proper get_label() function if data_property: if idhidden: def get_label(node): return getattr(node.data, data_property) else: def get_label(node): return "%s[%s]" % ( getattr(node.data, data_property), node.identifier, ) else: if idhidden: def get_label(node): return node.tag else: def get_label(node): return "%s[%s]" % (node.tag, node.identifier) # legacy ordering if sorting and key is None: def key(node): return node # iter with func for pre, node in self.__get( nid, level, filter, key, reverse, line_type, sorting ): label = get_label(node) func("{0}{1}".format(pre, label).encode("utf-8")) def __get(self, nid, level, filter_, key, reverse, line_type, sorting): # default filter if filter_ is None: def filter_(node): return True # render characters dt = { "ascii": ("|", "|-- ", "+-- "), "ascii-ex": ("\u2502", "\u251c\u2500\u2500 ", "\u2514\u2500\u2500 "), "ascii-exr": ("\u2502", "\u251c\u2500\u2500 ", "\u2570\u2500\u2500 "), "ascii-em": ("\u2551", "\u2560\u2550\u2550 ", "\u255a\u2550\u2550 "), "ascii-emv": ("\u2551", "\u255f\u2500\u2500 ", "\u2559\u2500\u2500 "), "ascii-emh": ("\u2502", "\u255e\u2550\u2550 ", "\u2558\u2550\u2550 "), }[line_type] return self.__get_iter(nid, level, filter_, key, reverse, dt, [], sorting) def __get_iter(self, nid, level, filter_, key, reverse, dt, is_last, sorting): dt_vertical_line, dt_line_box, dt_line_corner = dt nid = self.root if nid is None else nid node = self[nid] if level == self.ROOT: yield "", node else: leading = "".join( map( lambda x: dt_vertical_line + " " * 3 if not x else " " * 4, is_last[0:-1], ) ) lasting = dt_line_corner if is_last[-1] else dt_line_box yield leading + lasting, node if filter_(node) and node.expanded: children = [ self[i] for i in node.successors(self._identifier) if filter_(self[i]) ] idxlast = len(children) - 1 if sorting: if key: children.sort(key=key, reverse=reverse) elif reverse: children = reversed(children) level += 1 for idx, child in enumerate(children): is_last.append(idx == idxlast) for item in self.__get_iter( child.identifier, level, filter_, key, reverse, dt, is_last, sorting ): yield item is_last.pop() def __update_bpointer(self, nid, parent_id): """set self[nid].bpointer""" self[nid].set_predecessor(parent_id, self._identifier) def __update_fpointer(self, nid, child_id, mode): if nid is None: return else: self[nid].update_successors(child_id, mode, tree_id=self._identifier) def add_node(self, node, parent=None): """ Add a new node object to the tree and make the parent as the root by default. The 'node' parameter refers to an instance of Class::Node. """ if not isinstance(node, self.node_class): raise OSError( "First parameter must be object of {}".format(self.node_class) ) if node.identifier in self._nodes: raise DuplicatedNodeIdError( "Can't create node " "with ID '%s'" % node.identifier ) pid = parent.identifier if isinstance(parent, self.node_class) else parent if pid is None: if self.root is not None: raise MultipleRootError("A tree takes one root merely.") else: self.root = node.identifier elif not self.contains(pid): raise NodeIDAbsentError("Parent node '%s' " "is not in the tree" % pid) self._nodes.update({node.identifier: node}) self.__update_fpointer(pid, node.identifier, self.node_class.ADD) self.__update_bpointer(node.identifier, pid) node.set_initial_tree_id(self._identifier) def all_nodes(self): """Return all nodes in a list""" return list(self._nodes.values()) def all_nodes_itr(self): """ Returns all nodes in an iterator. Added by William Rusnack """ return self._nodes.values() def ancestor(self, nid, level=None): """ For a given id, get ancestor node object at a given level. If no level is provided, the parent node is returned. """ if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) descendant = self[nid] ascendant = self[nid]._predecessor[self._identifier] ascendant_level = self.level(ascendant) if level is None: return ascendant elif nid == self.root: return self[nid] elif level >= self.level(descendant.identifier): raise InvalidLevelNumber( "Descendant level (level %s) must be greater \ than its ancestor's level (level %s)" % (str(self.level(descendant.identifier)), level) ) while ascendant is not None: if ascendant_level == level: return self[ascendant] else: descendant = ascendant ascendant = self[descendant]._predecessor[self._identifier] ascendant_level = self.level(ascendant) return None def children(self, nid): """ Return the children (Node) list of nid. Empty list is returned if nid does not exist """ return [self[i] for i in self.is_branch(nid)] def contains(self, nid): """Check if the tree contains node of given id""" return True if nid in self._nodes else False def create_node(self, tag=None, identifier=None, parent=None, data=None): """ Create a child node for given @parent node. If ``identifier`` is absent, a UUID will be generated automatically. """ node = self.node_class(tag=tag, identifier=identifier, data=data) self.add_node(node, parent) return node def depth(self, node=None): """ Get the maximum level of this tree or the level of the given node. @param node Node instance or identifier @return int @throw NodeIDAbsentError """ ret = 0 if node is None: # Get maximum level of this tree leaves = self.leaves() for leave in leaves: level = self.level(leave.identifier) ret = level if level >= ret else ret else: # Get level of the given node if not isinstance(node, self.node_class): nid = node else: nid = node.identifier if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) ret = self.level(nid) return ret def expand_tree( self, nid=None, mode=DEPTH, filter=None, key=None, reverse=False, sorting=True ): """ Python generator to traverse the tree (or a subtree) with optional node filtering and sorting. Loosely based on an algorithm from 'Essential LISP' by John R. Anderson, Albert T. Corbett, and Brian J. Reiser, page 239-241. :param nid: Node identifier from which tree traversal will start. If None tree root will be used :param mode: Traversal mode, may be either DEPTH, WIDTH or ZIGZAG :param filter: the @filter function is performed on Node object during traversing. In this manner, the traversing will NOT visit the node whose condition does not pass the filter and its children. :param key: the @key and @reverse are present to sort nodes at each level. If @key is None sorting is performed on node tag. :param reverse: if True reverse sorting :param sorting: if True perform node sorting, if False return nodes in original insertion order. In latter case @key and @reverse parameters are ignored. :return: Node IDs that satisfy the conditions :rtype: generator object """ nid = self.root if nid is None else nid if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) filter = (lambda x: True) if (filter is None) else filter if filter(self[nid]): yield nid queue = [ self[i] for i in self[nid].successors(self._identifier) if filter(self[i]) ] if mode in [self.DEPTH, self.WIDTH]: if sorting: queue.sort(key=key, reverse=reverse) while queue: yield queue[0].identifier expansion = [ self[i] for i in queue[0].successors(self._identifier) if filter(self[i]) ] if sorting: expansion.sort(key=key, reverse=reverse) if mode is self.DEPTH: queue = expansion + queue[1:] # depth-first elif mode is self.WIDTH: queue = queue[1:] + expansion # width-first elif mode is self.ZIGZAG: # Suggested by Ilya Kuprik (ilya-spy@ynadex.ru). stack_fw = [] queue.reverse() stack = stack_bw = queue direction = False while stack: expansion = [ self[i] for i in stack[0].successors(self._identifier) if filter(self[i]) ] yield stack.pop(0).identifier if direction: expansion.reverse() stack_bw = expansion + stack_bw else: stack_fw = expansion + stack_fw if not stack: direction = not direction stack = stack_fw if direction else stack_bw else: raise ValueError("Traversal mode '{}' is not supported".format(mode)) def filter_nodes(self, func): """ Filters all nodes by function. :param func: is passed one node as an argument and that node is included if function returns true, :return: a filter iterator of the node in python 3 or a list of the nodes in python 2. Added by William Rusnack. """ return filter(func, self.all_nodes_itr()) def get_node(self, nid): """ Get the object of the node with ID of ``nid``. An alternative way is using '[]' operation on the tree. But small difference exists between them: ``get_node()`` will return None if ``nid`` is absent, whereas '[]' will raise ``KeyError``. """ if nid is None or not self.contains(nid): return None return self._nodes[nid] def is_branch(self, nid): """ Return the children (ID) list of nid. Empty list is returned if nid does not exist """ if nid is None: raise OSError("First parameter can't be None") if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) try: fpointer = self[nid].successors(self._identifier) except KeyError: fpointer = [] return fpointer def leaves(self, nid=None): """Get leaves of the whole tree or a subtree.""" leaves = [] if nid is None: for node in self._nodes.values(): if node.is_leaf(self._identifier): leaves.append(node) else: for node in self.expand_tree(nid): if self[node].is_leaf(self._identifier): leaves.append(self[node]) return leaves def level(self, nid, filter=None): """ Get the node level in this tree. The level is an integer starting with '0' at the root. In other words, the root lives at level '0'; Update: @filter params is added to calculate level passing exclusive nodes. """ return len([n for n in self.rsearch(nid, filter)]) - 1 def link_past_node(self, nid): """ Delete a node by linking past it. For example, if we have `a -> b -> c` and delete node b, we are left with `a -> c`. """ if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) if self.root == nid: raise LinkPastRootNodeError( "Cannot link past the root node, " "delete it with remove_node()" ) # Get the parent of the node we are linking past parent = self[self[nid].predecessor(self._identifier)] # Set the children of the node to the parent for child in self[nid].successors(self._identifier): self[child].set_predecessor(parent.identifier, self._identifier) # Link the children to the parent for id_ in self[nid].successors(self._identifier) or []: parent.update_successors(id_, tree_id=self._identifier) # Delete the node parent.update_successors(nid, mode=parent.DELETE, tree_id=self._identifier) del self._nodes[nid] def move_node(self, source, destination): """ Move node @source from its parent to another parent @destination. """ if not self.contains(source) or not self.contains(destination): raise NodeIDAbsentError elif self.is_ancestor(source, destination): raise LoopError parent = self[source].predecessor(self._identifier) self.__update_fpointer(parent, source, self.node_class.DELETE) self.__update_fpointer(destination, source, self.node_class.ADD) self.__update_bpointer(source, destination) def is_ancestor(self, ancestor, grandchild): """ Check if the @ancestor the preceding nodes of @grandchild. :param ancestor: the node identifier :param grandchild: the node identifier :return: True or False """ parent = self[grandchild].predecessor(self._identifier) child = grandchild while parent is not None: if parent == ancestor: return True else: child = self[child].predecessor(self._identifier) parent = self[child].predecessor(self._identifier) return False @property def nodes(self): """Return a dict form of nodes in a tree: {id: node_instance}.""" return self._nodes def parent(self, nid): """Get parent :class:`Node` object of given id.""" if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) pid = self[nid].predecessor(self._identifier) if pid is None or not self.contains(pid): return None return self[pid] def merge(self, nid, new_tree, deep=False): """Patch @new_tree on current tree by pasting new_tree root children on current tree @nid node. Consider the following tree: >>> current.show() root ├── A └── B >>> new_tree.show() root2 ├── C └── D └── D1 Merging new_tree on B node: >>>current.merge('B', new_tree) >>>current.show() root ├── A └── B ├── C └── D └── D1 Note: if current tree is empty and nid is None, the new_tree root will be used as root on current tree. In all other cases new_tree root is not pasted. """ if new_tree.root is None: return if nid is None: if self.root is None: new_tree_root = new_tree[new_tree.root] self.add_node(new_tree_root) nid = new_tree.root else: raise ValueError('Must define "nid" under which new tree is merged.') for child in new_tree.children(new_tree.root): self.paste(nid=nid, new_tree=new_tree.subtree(child.identifier), deep=deep) def paste(self, nid, new_tree, deep=False): """ Paste a @new_tree to the original one by linking the root of new tree to given node (nid). Update: add @deep copy of pasted tree. """ assert isinstance(new_tree, Tree) if new_tree.root is None: return if nid is None: raise ValueError('Must define "nid" under which new tree is pasted.') if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) set_joint = set(new_tree._nodes) & set(self._nodes) # joint keys if set_joint: raise ValueError("Duplicated nodes %s exists." % list(map(text, set_joint))) for cid, node in iteritems(new_tree.nodes): if deep: node = deepcopy(new_tree[node]) self._nodes.update({cid: node}) node.clone_pointers(new_tree.identifier, self._identifier) self.__update_bpointer(new_tree.root, nid) self.__update_fpointer(nid, new_tree.root, self.node_class.ADD) def paths_to_leaves(self): """ Use this function to get the identifiers allowing to go from the root nodes to each leaf. :return: a list of list of identifiers, root being not omitted. For example: .. code-block:: python Harry |___ Bill |___ Jane | |___ Diane | |___ George | |___ Jill | |___ Mary | |___ Mark Expected result: .. code-block:: python [['harry', 'jane', 'diane', 'mary'], ['harry', 'jane', 'mark'], ['harry', 'jane', 'diane', 'george', 'jill'], ['harry', 'bill']] """ res = [] for leaf in self.leaves(): res.append([nid for nid in self.rsearch(leaf.identifier)][::-1]) return res def remove_node(self, identifier): """Remove a node indicated by 'identifier' with all its successors. Return the number of removed nodes. """ if not self.contains(identifier): raise NodeIDAbsentError("Node '%s' " "is not in the tree" % identifier) parent = self[identifier].predecessor(self._identifier) # Remove node and its children removed = list(self.expand_tree(identifier)) for id_ in removed: if id_ == self.root: self.root = None self.__update_bpointer(id_, None) for cid in self[id_].successors(self._identifier) or []: self.__update_fpointer(id_, cid, self.node_class.DELETE) # Update parent info self.__update_fpointer(parent, identifier, self.node_class.DELETE) self.__update_bpointer(identifier, None) for id_ in removed: self.nodes.pop(id_) return len(removed) def remove_subtree(self, nid, identifier=None): """ Get a subtree with ``nid`` being the root. If nid is None, an empty tree is returned. For the original tree, this method is similar to `remove_node(self,nid)`, because given node and its children are removed from the original tree in both methods. For the returned value and performance, these two methods are different: * `remove_node` returns the number of deleted nodes; * `remove_subtree` returns a subtree of deleted nodes; You are always suggested to use `remove_node` if your only to delete nodes from a tree, as the other one need memory allocation to store the new tree. :return: a :class:`Tree` object. """ st = self._clone(identifier) if nid is None: return st if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) st.root = nid # in original tree, the removed nid will be unreferenced from its parents children parent = self[nid].predecessor(self._identifier) removed = list(self.expand_tree(nid)) for id_ in removed: if id_ == self.root: self.root = None st._nodes.update({id_: self._nodes.pop(id_)}) st[id_].clone_pointers(self._identifier, st.identifier) st[id_].reset_pointers(self._identifier) if id_ == nid: st[id_].set_predecessor(None, st.identifier) self.__update_fpointer(parent, nid, self.node_class.DELETE) return st def rsearch(self, nid, filter=None): """ Traverse the tree branch along the branch from nid to its ancestors (until root). :param filter: the function of one variable to act on the :class:`Node` object. """ if nid is None: return if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) filter = (lambda x: True) if (filter is None) else filter current = nid while current is not None: if filter(self[current]): yield current # subtree() hasn't update the bpointer current = ( self[current].predecessor(self._identifier) if self.root != current else None ) def save2file( self, filename, nid=None, level=ROOT, idhidden=True, filter=None, key=None, reverse=False, line_type="ascii-ex", data_property=None, sorting=True, ): """ Save the tree into file for offline analysis. """ def _write_line(line, f): f.write(line + b"\n") def handler(x): return _write_line(x, open(filename, "ab")) self.__print_backend( nid, level, idhidden, filter, key, reverse, line_type, data_property, sorting, func=handler, ) def show( self, nid=None, level=ROOT, idhidden=True, filter=None, key=None, reverse=False, line_type="ascii-ex", data_property=None, stdout=True, sorting=True, ): """ Print the tree structure in hierarchy style. You have three ways to output your tree data, i.e., stdout with ``show()``, plain text file with ``save2file()``, and json string with ``to_json()``. The former two use the same backend to generate a string of tree structure in a text graph. * Version >= 1.2.7a*: you can also specify the ``line_type`` parameter, such as 'ascii' (default), 'ascii-ex', 'ascii-exr', 'ascii-em', 'ascii-emv', 'ascii-emh') to the change graphical form. :param nid: the reference node to start expanding. :param level: the node level in the tree (root as level 0). :param idhidden: whether hiding the node ID when printing. :param filter: the function of one variable to act on the :class:`Node` object. When this parameter is specified, the traversing will not continue to following children of node whose condition does not pass the filter. :param key: the ``key`` param for sorting :class:`Node` objects in the same level. :param reverse: the ``reverse`` param for sorting :class:`Node` objects in the same level. :param line_type: :param data_property: the property on the node data object to be printed. :param sorting: if True perform node sorting, if False return nodes in original insertion order. In latter case @key and @reverse parameters are ignored. :return: None """ self._reader = "" def write(line): self._reader += line.decode("utf-8") + "\n" try: self.__print_backend( nid, level, idhidden, filter, key, reverse, line_type, data_property, sorting, func=write, ) except NodeIDAbsentError: print("Tree is empty") if stdout: print(self._reader.encode("utf-8")) else: return self._reader def siblings(self, nid): """ Return the siblings of given @nid. If @nid is root or there are no siblings, an empty list is returned. """ siblings = [] if nid != self.root: pid = self[nid].predecessor(self._identifier) siblings = [ self[i] for i in self[pid].successors(self._identifier) if i != nid ] return siblings def size(self, level=None): """ Get the number of nodes of the whole tree if @level is not given. Otherwise, the total number of nodes at specific level is returned. @param level The level number in the tree. It must be between [0, tree.depth]. Otherwise, InvalidLevelNumber exception will be raised. """ if level is None: return len(self._nodes) else: try: level = int(level) return len( [ node for node in self.all_nodes_itr() if self.level(node.identifier) == level ] ) except Exception: raise TypeError( "level should be an integer instead of '%s'" % type(level) ) def subtree(self, nid, identifier=None): """ Return a shallow COPY of subtree with nid being the new root. If nid is None, return an empty tree. If you are looking for a deepcopy, please create a new tree with this shallow copy, e.g., .. code-block:: python new_tree = Tree(t.subtree(t.root), deep=True) This line creates a deep copy of the entire tree. """ st = self._clone(identifier) if nid is None: return st if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) st.root = nid for node_n in self.expand_tree(nid): st._nodes.update({self[node_n].identifier: self[node_n]}) # define nodes parent/children in this tree # all pointers are the same as copied tree, except the root st[node_n].clone_pointers(self._identifier, st.identifier) if node_n == nid: # reset root parent for the new tree st[node_n].set_predecessor(None, st.identifier) return st def update_node(self, nid, **attrs): """ Update node's attributes. :param nid: the identifier of modified node :param attrs: attribute pairs recognized by Node object :return: None """ cn = self[nid] for attr, val in iteritems(attrs): if attr == "identifier": # Updating node id meets following contraints: # * Update node identifier property # * Update parent's followers # * Update children's parents # * Update tree registration of var _nodes # * Update tree root if necessary cn = self._nodes.pop(nid) setattr(cn, "identifier", val) self._nodes[val] = cn if cn.predecessor(self._identifier) is not None: self[cn.predecessor(self._identifier)].update_successors( nid, mode=self.node_class.REPLACE, replace=val, tree_id=self._identifier, ) for fp in cn.successors(self._identifier): self[fp].set_predecessor(val, self._identifier) if self.root == nid: self.root = val else: setattr(cn, attr, val) def to_dict(self, nid=None, key=None, sort=True, reverse=False, with_data=False): """Transform the whole tree into a dict.""" nid = self.root if (nid is None) else nid ntag = self[nid].tag tree_dict = {ntag: {"children": []}} if with_data: tree_dict[ntag]["data"] = self[nid].data if self[nid].expanded: queue = [self[i] for i in self[nid].successors(self._identifier)] key = (lambda x: x) if (key is None) else key if sort: queue.sort(key=key, reverse=reverse) for elem in queue: tree_dict[ntag]["children"].append( self.to_dict( elem.identifier, with_data=with_data, sort=sort, reverse=reverse ) ) if len(tree_dict[ntag]["children"]) == 0: tree_dict = ( self[nid].tag if not with_data else {ntag: {"data": self[nid].data}} ) return tree_dict def to_json(self, with_data=False, sort=True, reverse=False): """To format the tree in JSON format.""" return json.dumps(self.to_dict(with_data=with_data, sort=sort, reverse=reverse)) def to_graphviz( self, filename=None, shape="circle", graph="digraph", filter=None, key=None, reverse=False, sorting=True, ): """Exports the tree in the dot format of the graphviz software""" nodes, connections = [], [] if self.nodes: for n in self.expand_tree( mode=self.WIDTH, filter=filter, key=key, reverse=reverse, sorting=sorting, ): nid = self[n].identifier state = '"{0}" [label="{1}", shape={2}]'.format(nid, self[n].tag, shape) nodes.append(state) for c in self.children(nid): cid = c.identifier edge = "->" if graph == "digraph" else "--" connections.append(('"{0}" ' + edge + ' "{1}"').format(nid, cid)) # write nodes and connections to dot format is_plain_file = filename is not None if is_plain_file: f = codecs.open(filename, "w", "utf-8") else: f = StringIO() f.write(graph + " tree {\n") for n in nodes: f.write("\t" + n + "\n") if len(connections) > 0: f.write("\n") for c in connections: f.write("\t" + c + "\n") f.write("}") if not is_plain_file: print(f.getvalue()) f.close() @classmethod def from_map(cls, child_parent_dict, id_func=None, data_func=None): """ takes a dict with child:parent, and form a tree """ tree = Tree() if tree is None or tree.size() > 0: raise ValueError("need to pass in an empty tree") id_func = id_func if id_func else lambda x: x data_func = data_func if data_func else lambda x: None parent_child_dict = {} root_node = None for k, v in child_parent_dict.items(): if v is None and root_node is None: root_node = k elif v is None and root_node is not None: raise ValueError("invalid input, more than 1 child has no parent") else: if v in parent_child_dict: parent_child_dict[v].append(k) else: parent_child_dict[v] = [k] if root_node is None: raise ValueError("cannot find root") tree.create_node(root_node, id_func(root_node), data=data_func(root_node)) queue = [root_node] while len(queue) > 0: parent_node = queue.pop() for child in parent_child_dict.get(parent_node, []): tree.create_node( child, id_func(child), parent=id_func(parent_node), data=data_func(child), ) queue.append(child) return tree