Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class EquipmentTreeBuilder(StepActionWithContextValue):
"""

_roots: dict[ConductingEquipment, EquipmentTreeNode] = {}
_leaves: set[EquipmentTreeNode] = set()

def __init__(self):
super().__init__(key=str(uuid.uuid4()))
Expand All @@ -43,6 +44,26 @@ def __init__(self):
def roots(self) -> Generator[TreeNode[ConductingEquipment], None, None]:
return (r for r in self._roots.values())

def recurse_nodes(self) -> Generator[TreeNode[ConductingEquipment], None, None]:
"""
Returns a generator that will yield every node in the tree structure.
"""
def recurse(node: TreeNode[ConductingEquipment]):
yield node
for child in node.children:
yield from recurse(child)

for root in self._roots.values():
yield from recurse(root)

@property
def leaves(self) -> set[EquipmentTreeNode]:
"""
Return the leaves of the tree structure. Depending on how the backing trace is configured,
there may be extra unexpected leaves in loops.
"""
return set(self._leaves)

def compute_initial_value(self, item: NetworkTraceStep[Any]) -> EquipmentTreeNode:
node = self._roots.get(item.path.to_equipment)
if node is None:
Expand All @@ -64,7 +85,9 @@ def compute_next_value(

def _apply(self, item: NetworkTraceStep[Any], context: StepContext):
current_node: TreeNode = self.get_context_value(context)
self._leaves.add(current_node) # add this node to _leaves as it has no children
if current_node.parent:
self._leaves.discard(current_node.parent) # this nodes parent now has a child, it's not a leaf anymore
current_node.parent.add_child(current_node)

def clear(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,31 @@

__all__ = ['TreeNode']

from typing import List, TypeVar, Generic

from zepben.ewb import IdentifiedObject
from typing import TypeVar, Generic, Set

T = TypeVar('T')


class TreeNode(Generic[T]):
"""
represents a node in the NetworkTrace tree
Represents a node in the NetworkTrace tree
"""

def __init__(self, identified_object: IdentifiedObject, parent=None):
def __init__(self, identified_object: T, parent=None):
self.identified_object = identified_object
self._parent: TreeNode = parent
self._children: List[TreeNode] = []
self._children: Set[TreeNode] = set()

@property
def parent(self) -> 'TreeNode[T]':
return self._parent

@property
def children(self):
return list(self._children)
def children(self) -> Set['TreeNode[T]']:
return set(self._children)

def add_child(self, child: 'TreeNode'):
self._children.append(child)
def add_child(self, child: 'TreeNode[T]'):
self._children.add(child)

def __str__(self):
return f"{{object: {self.identified_object}, parent: {self.parent or ''}, num children: {len(self.children)}}}"
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,15 @@ async def test_downstream_tree():
start = n.get("j1", ConductingEquipment)
assert start is not None
tree_builder = EquipmentTreeBuilder()
trace = Tracing.network_trace_branching(
network_state_operators=normal,
action_step_type=NetworkTraceActionType.FIRST_STEP_ON_EQUIPMENT) \
.add_condition(downstream()) \
.add_step_action(tree_builder) \
trace = (
Tracing.network_trace_branching(
network_state_operators=normal,
action_step_type=NetworkTraceActionType.FIRST_STEP_ON_EQUIPMENT
)
.add_condition(downstream())
.add_step_action(tree_builder)
.add_step_action(lambda item, context: visited_ce.append(item.path.to_equipment.mrid))
)

await trace.run(start)

Expand All @@ -51,34 +54,39 @@ async def test_downstream_tree():

pprint.pprint(visit_counts)

root = list(tree_builder.roots)[0]
root = tree_builder._roots[start]

assert root is not None
_verify_tree_asset(root, n["j1"], None, [n["ac1"], n["ac3"]])

test_node = root.children[0]
_verify_tree_asset(test_node, n["ac1"], n["j1"], [n["j2"]])
assert len(root.children) == 2
for test_node in root.children:
if test_node.identified_object == n['ac1']:
_verify_tree_asset(test_node, n["ac1"], n["j1"], [n["j2"]])

test_node = test_node.children[0]
_verify_tree_asset(test_node, n["j2"], n["ac1"], [n["ac2"]])
test_node = test_node.children.pop()
_verify_tree_asset(test_node, n["j2"], n["ac1"], [n["ac2"]])

test_node = test_node.children[0]
_verify_tree_asset(test_node, n["ac2"], n["j2"], [n["j3"]])
test_node = test_node.children.pop()
_verify_tree_asset(test_node, n["ac2"], n["j2"], [n["j3"]])

test_node = next(iter(test_node.children))
_verify_tree_asset(test_node, n["j3"], n["ac2"], [n["ac4"]])
test_node = next(iter(test_node.children))
_verify_tree_asset(test_node, n["j3"], n["ac2"], [n["ac4"]])

test_node = next(iter(test_node.children))
_verify_tree_asset(test_node, n["ac4"], n["j3"], [n["j6"]])
test_node = next(iter(test_node.children))
_verify_tree_asset(test_node, n["ac4"], n["j3"], [n["j6"]])

test_node = next(iter(test_node.children))
_verify_tree_asset(test_node, n["j6"], n["ac4"], [])
test_node = next(iter(test_node.children))
_verify_tree_asset(test_node, n["j6"], n["ac4"], [])
break

test_node = list(root.children)[1]
_verify_tree_asset(test_node, n["ac3"], n["j1"], [n["j4"]])
elif test_node.identified_object == n['ac3']:
_verify_tree_asset(test_node, n["ac3"], n["j1"], [n["j4"]])

test_node = next(iter(test_node.children))
_verify_tree_asset(test_node, n["j4"], n["ac3"], [n["ac5"], n["ac6"]])
test_node = next(iter(test_node.children))
_verify_tree_asset(test_node, n["j4"], n["ac3"], [n["ac5"], n["ac6"]])
else:
assert False

assert len(_find_nodes(root, "j0")) == 0
assert len(_find_nodes(root, "ac0")) == 0
Expand Down Expand Up @@ -147,6 +155,10 @@ async def test_downstream_tree():
assert _find_node_depths(root, "ac16") == [8, 9, 11, 14]


for ce in (n['j5'], n['j13']):
assert ce in {l.identified_object for l in tree_builder.leaves}


def _verify_tree_asset(
tree_node: TreeNode,
expected_asset: Optional[ConductingEquipment],
Expand All @@ -162,8 +174,15 @@ def _verify_tree_asset(
else:
assert tree_node.parent is None

children_nodes = list(c.identified_object for c in tree_node.children)
assert children_nodes == expected_children
children_nodes = [c.identified_object for c in tree_node.children]
try:
for child in expected_children:
assert child in children_nodes
for child in children_nodes:
assert child in expected_children
except AssertionError as e:
e.args = (expected_children, children_nodes)
raise e


def _find_nodes(root: TreeNode[ConductingEquipment], asset_id: str) -> List[TreeNode[ConductingEquipment]]:
Expand Down