22Belief Propagation (BP) algorithm implementation using PyTorch.
33"""
44
5- from typing import List , Dict , Tuple , Optional
5+ from typing import List , Dict , Tuple
66import torch
77from copy import deepcopy
88
@@ -88,7 +88,7 @@ def initial_state(bp: BeliefPropagation) -> BPState:
8888 var_messages_in = []
8989 var_messages_out = []
9090
91- for factor_idx in bp .v2t [var_idx ]:
91+ for _ in bp .v2t [var_idx ]:
9292 card = bp .cards [var_idx ]
9393 msg = torch .ones (card , dtype = torch .float64 )
9494 var_messages_in .append (msg .clone ())
@@ -97,7 +97,7 @@ def initial_state(bp: BeliefPropagation) -> BPState:
9797 message_in .append (var_messages_in )
9898 message_out .append (var_messages_out )
9999
100- return BPState (deepcopy ( message_in ) , message_out )
100+ return BPState (message_in , message_out )
101101
102102
103103def _compute_factor_to_var_message (
@@ -124,7 +124,7 @@ def _compute_factor_to_var_message(
124124 return factor_tensor .clone ()
125125
126126 # Multiply factor tensor by incoming messages (excluding target) and sum out dims.
127- result = factor_tensor
127+ result = factor_tensor . clone ()
128128 for dim in range (ndims ):
129129 if dim == target_var_idx :
130130 continue
@@ -154,11 +154,13 @@ def collect_message(bp: BeliefPropagation, state: BPState, normalize: bool = Tru
154154 for factor_idx , factor in enumerate (bp .factors ):
155155 # Get incoming messages from variables to this factor
156156 incoming_messages = []
157+ var_factor_positions = []
157158 for var in factor .vars :
158159 var_idx_0based = var - 1
159160 # Find position of this factor in v2t[var_idx_0based]
160161 factor_pos = bp .v2t [var_idx_0based ].index (factor_idx )
161162 incoming_messages .append (state .message_out [var_idx_0based ][factor_pos ])
163+ var_factor_positions .append (factor_pos )
162164
163165 # Compute outgoing message to each variable
164166 for var_pos , var in enumerate (factor .vars ):
@@ -177,7 +179,7 @@ def collect_message(bp: BeliefPropagation, state: BPState, normalize: bool = Tru
177179 outgoing_msg = outgoing_msg / msg_sum
178180
179181 # Update message_in
180- factor_pos = bp . v2t [ var_idx_0based ]. index ( factor_idx )
182+ factor_pos = var_factor_positions [ var_pos ]
181183 state .message_in [var_idx_0based ][factor_pos ] = outgoing_msg
182184
183185
@@ -334,19 +336,17 @@ def apply_evidence(bp: BeliefPropagation, evidence: Dict[int, int]) -> BeliefPro
334336 for var_pos , var in enumerate (factor .vars ):
335337 if var in evidence :
336338 evid_value = evidence [var ]
337- # Create slice that zeros out non-evidence values
338- slices = [slice (None )] * len (factor .vars )
339- slices [var_pos ] = evid_value
340-
341- # Zero out all non-evidence assignments
342- mask = torch .ones_like (factor_tensor )
343- for i in range (factor_tensor .shape [var_pos ]):
344- if i != evid_value :
345- slices_mask = slices .copy ()
346- slices_mask [var_pos ] = i
347- mask [tuple (slices_mask )] = 0
348-
349- factor_tensor = factor_tensor * mask
339+ dim_size = factor_tensor .shape [var_pos ]
340+ if 0 <= evid_value < dim_size :
341+ all_indices = torch .arange (dim_size , device = factor_tensor .device )
342+ zero_indices = all_indices [all_indices != evid_value ]
343+ if zero_indices .numel () > 0 :
344+ factor_tensor = factor_tensor .index_fill (
345+ var_pos , zero_indices , 0
346+ )
347+ else :
348+ factor_tensor = torch .zeros_like (factor_tensor )
349+ break
350350
351351 new_factors .append (Factor (factor .vars , factor_tensor ))
352352
0 commit comments