From d0ac4027c234df15597f839d3e2e9aecf8a8fd68 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 14 Nov 2025 15:25:42 +0000 Subject: [PATCH] Quick perf pass --- tstrait/genetic_value.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/tstrait/genetic_value.py b/tstrait/genetic_value.py index 30b8954..9358fbe 100644 --- a/tstrait/genetic_value.py +++ b/tstrait/genetic_value.py @@ -10,23 +10,33 @@ def _compute_nodes_genetic_value( left_child_array, right_sib_array, - stack, + seed_nodes, has_mutation, num_nodes, effect_size, ): # pragma: no cover """ Compute the genetic value of each node for the specified set of mutations - encoded in the stack. + encoded in ``seed_nodes``. """ genetic_value = np.zeros(num_nodes) - while len(stack) > 0: - parent_node_id = stack.pop() + stack = np.empty(num_nodes, dtype=np.int64) + top = 0 + + # Initialise stack with seed nodes. + for i in range(seed_nodes.shape[0]): + stack[top] = seed_nodes[i] + top += 1 + + while top > 0: + top -= 1 + parent_node_id = stack[top] genetic_value[parent_node_id] = effect_size child_node_id = left_child_array[parent_node_id] while child_node_id != -1: if not has_mutation[child_node_id]: - stack.append(child_node_id) + stack[top] = child_node_id + top += 1 child_node_id = right_sib_array[child_node_id] return genetic_value @@ -72,18 +82,19 @@ def _individual_genetic_values(self, tree, site, causal_allele, effect_size): for m in site.mutations: state_transitions[m.node] = m.derived_state has_mutation[m.node] = True - stack = numba.typed.List() + seed_nodes = [] for node, allele in state_transitions.items(): if allele == causal_allele: - stack.append(node) + seed_nodes.append(node) - if len(stack) == 0: + if len(seed_nodes) == 0: genetic_value = np.zeros(self.ts.num_nodes) else: + seed_nodes_array = np.array(seed_nodes, dtype=np.int64) genetic_value = _compute_nodes_genetic_value( left_child_array=tree.left_child_array, right_sib_array=tree.right_sib_array, - stack=stack, + seed_nodes=seed_nodes_array, has_mutation=has_mutation, num_nodes=self.ts.num_nodes, effect_size=effect_size,