Skip to content

Commit fc89e7f

Browse files
Apply BP fixes and update tests
1 parent 0c0ed47 commit fc89e7f

8 files changed

Lines changed: 64 additions & 95 deletions

File tree

src/bpdecoderplus/pytorch_bp/belief_propagation.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Belief Propagation (BP) algorithm implementation using PyTorch.
33
"""
44

5-
from typing import List, Dict, Tuple, Optional
5+
from typing import List, Dict, Tuple
66
import torch
77
from copy import deepcopy
88

@@ -88,7 +88,7 @@ def initial_state(bp: BeliefPropagation) -> BPState:
8888
var_messages_in = []
8989
var_messages_out = []
9090

91-
for factor_idx in bp.v2t[var_idx]:
91+
for _ in bp.v2t[var_idx]:
9292
card = bp.cards[var_idx]
9393
msg = torch.ones(card, dtype=torch.float64)
9494
var_messages_in.append(msg.clone())
@@ -97,7 +97,7 @@ def initial_state(bp: BeliefPropagation) -> BPState:
9797
message_in.append(var_messages_in)
9898
message_out.append(var_messages_out)
9999

100-
return BPState(deepcopy(message_in), message_out)
100+
return BPState(message_in, message_out)
101101

102102

103103
def _compute_factor_to_var_message(
@@ -124,7 +124,7 @@ def _compute_factor_to_var_message(
124124
return factor_tensor.clone()
125125

126126
# Multiply factor tensor by incoming messages (excluding target) and sum out dims.
127-
result = factor_tensor
127+
result = factor_tensor.clone()
128128
for dim in range(ndims):
129129
if dim == target_var_idx:
130130
continue
@@ -154,11 +154,13 @@ def collect_message(bp: BeliefPropagation, state: BPState, normalize: bool = Tru
154154
for factor_idx, factor in enumerate(bp.factors):
155155
# Get incoming messages from variables to this factor
156156
incoming_messages = []
157+
var_factor_positions = []
157158
for var in factor.vars:
158159
var_idx_0based = var - 1
159160
# Find position of this factor in v2t[var_idx_0based]
160161
factor_pos = bp.v2t[var_idx_0based].index(factor_idx)
161162
incoming_messages.append(state.message_out[var_idx_0based][factor_pos])
163+
var_factor_positions.append(factor_pos)
162164

163165
# Compute outgoing message to each variable
164166
for var_pos, var in enumerate(factor.vars):
@@ -177,7 +179,7 @@ def collect_message(bp: BeliefPropagation, state: BPState, normalize: bool = Tru
177179
outgoing_msg = outgoing_msg / msg_sum
178180

179181
# Update message_in
180-
factor_pos = bp.v2t[var_idx_0based].index(factor_idx)
182+
factor_pos = var_factor_positions[var_pos]
181183
state.message_in[var_idx_0based][factor_pos] = outgoing_msg
182184

183185

@@ -334,19 +336,17 @@ def apply_evidence(bp: BeliefPropagation, evidence: Dict[int, int]) -> BeliefPro
334336
for var_pos, var in enumerate(factor.vars):
335337
if var in evidence:
336338
evid_value = evidence[var]
337-
# Create slice that zeros out non-evidence values
338-
slices = [slice(None)] * len(factor.vars)
339-
slices[var_pos] = evid_value
340-
341-
# Zero out all non-evidence assignments
342-
mask = torch.ones_like(factor_tensor)
343-
for i in range(factor_tensor.shape[var_pos]):
344-
if i != evid_value:
345-
slices_mask = slices.copy()
346-
slices_mask[var_pos] = i
347-
mask[tuple(slices_mask)] = 0
348-
349-
factor_tensor = factor_tensor * mask
339+
dim_size = factor_tensor.shape[var_pos]
340+
if 0 <= evid_value < dim_size:
341+
all_indices = torch.arange(dim_size, device=factor_tensor.device)
342+
zero_indices = all_indices[all_indices != evid_value]
343+
if zero_indices.numel() > 0:
344+
factor_tensor = factor_tensor.index_fill(
345+
var_pos, zero_indices, 0
346+
)
347+
else:
348+
factor_tensor = torch.zeros_like(factor_tensor)
349+
break
350350

351351
new_factors.append(Factor(factor.vars, factor_tensor))
352352

src/bpdecoderplus/pytorch_bp/uai_parser.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
UAI file format parser for Belief Propagation.
33
"""
44

5-
from typing import List, Dict, Tuple
5+
from typing import List, Dict
66
import torch
77

88

@@ -71,6 +71,10 @@ def read_model_from_string(content: str, factor_eltype=torch.float64) -> UAIMode
7171

7272
# Parse header
7373
network_type = lines[0] # MARKOV or BAYES
74+
if network_type not in ("MARKOV", "BAYES"):
75+
raise ValueError(
76+
f"Unsupported UAI network type: {network_type!r}. Expected 'MARKOV' or 'BAYES'."
77+
)
7478
nvars = int(lines[1])
7579
cards = [int(x) for x in lines[2].split()]
7680
ntables = int(lines[3])
@@ -80,6 +84,11 @@ def read_model_from_string(content: str, factor_eltype=torch.float64) -> UAIMode
8084
for i in range(ntables):
8185
parts = lines[4 + i].split()
8286
scope_size = int(parts[0])
87+
if len(parts) - 1 != scope_size:
88+
raise ValueError(
89+
f"Scope size mismatch on line {4 + i}: "
90+
f"declared {scope_size}, found {len(parts) - 1} variables."
91+
)
8392
scope = [int(x) + 1 for x in parts[1:]] # Convert to 1-based
8493
scopes.append(scope)
8594

src/bpdecoderplus/pytorch_bp/utils.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

tests/_path.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,7 @@
44

55
def add_project_root_to_path():
66
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
7-
if project_root not in sys.path:
8-
sys.path.insert(0, project_root)
7+
src_root = os.path.join(project_root, "src")
8+
for path in (src_root, project_root):
9+
if path not in sys.path:
10+
sys.path.insert(0, path)

tests/test_bp_basic.py

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import unittest
2-
import itertools
32
import torch
43

4+
try:
5+
from ._path import add_project_root_to_path
6+
except ImportError:
7+
from _path import add_project_root_to_path
8+
9+
add_project_root_to_path()
10+
511
from bpdecoderplus.pytorch_bp import (
612
read_model_from_string,
713
BeliefPropagation,
@@ -10,32 +16,7 @@
1016
apply_evidence,
1117
)
1218

13-
14-
def exact_marginals(model, evidence=None):
15-
evidence = evidence or {}
16-
assignments = list(itertools.product(*[range(c) for c in model.cards]))
17-
weights = []
18-
for assignment in assignments:
19-
if any(assignment[var_idx - 1] != val for var_idx, val in evidence.items()):
20-
weights.append(0.0)
21-
continue
22-
weight = 1.0
23-
for factor in model.factors:
24-
idx = tuple(assignment[v - 1] for v in factor.vars)
25-
weight *= float(factor.values[idx])
26-
weights.append(weight)
27-
total = sum(weights)
28-
marginals = {}
29-
for var_idx, card in enumerate(model.cards):
30-
values = []
31-
for value in range(card):
32-
mass = 0.0
33-
for assignment, weight in zip(assignments, weights):
34-
if assignment[var_idx] == value:
35-
mass += weight
36-
values.append(mass / total if total > 0 else 0.0)
37-
marginals[var_idx + 1] = torch.tensor(values, dtype=torch.float64)
38-
return marginals
19+
from tests.test_utils import exact_marginals
3920

4021

4122
class TestBeliefPropagationBasic(unittest.TestCase):

tests/test_integration.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
import unittest
22

3+
try:
4+
from ._path import add_project_root_to_path
5+
except ImportError:
6+
from _path import add_project_root_to_path
7+
8+
add_project_root_to_path()
9+
310
from bpdecoderplus.pytorch_bp import (
411
read_model_file,
512
read_evidence_file,
@@ -15,7 +22,7 @@ def test_example_file_runs(self):
1522
model = read_model_file("examples/simple_model.uai")
1623
bp = BeliefPropagation(model)
1724
state, info = belief_propagate(bp, max_iter=30, tol=1e-8)
18-
self.assertTrue(info.iterations > 0)
25+
self.assertGreater(info.iterations, 0)
1926
marginals = compute_marginals(state, bp)
2027
self.assertEqual(set(marginals.keys()), {1, 2})
2128

@@ -24,7 +31,7 @@ def test_example_with_evidence(self):
2431
evidence = read_evidence_file("examples/simple_model.evid")
2532
bp = apply_evidence(BeliefPropagation(model), evidence)
2633
state, info = belief_propagate(bp, max_iter=30, tol=1e-8)
27-
self.assertTrue(info.iterations > 0)
34+
self.assertGreater(info.iterations, 0)
2835
marginals = compute_marginals(state, bp)
2936
self.assertEqual(set(marginals.keys()), {1, 2})
3037

tests/test_uai_parser.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
import unittest
22
import torch
33

4+
try:
5+
from ._path import add_project_root_to_path
6+
except ImportError:
7+
from _path import add_project_root_to_path
8+
9+
add_project_root_to_path()
10+
411
from bpdecoderplus.pytorch_bp import read_model_from_string, read_evidence_file
512

613

tests/testcase.py

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import unittest
2-
import itertools
32
import torch
43

4+
try:
5+
from ._path import add_project_root_to_path
6+
except ImportError:
7+
from _path import add_project_root_to_path
8+
9+
add_project_root_to_path()
10+
511
from bpdecoderplus.pytorch_bp import (
612
read_model_from_string,
713
BeliefPropagation,
@@ -13,31 +19,7 @@
1319
)
1420

1521

16-
def exact_marginals(model, evidence=None):
17-
evidence = evidence or {}
18-
assignments = list(itertools.product(*[range(c) for c in model.cards]))
19-
weights = []
20-
for assignment in assignments:
21-
if any(assignment[var_idx - 1] != val for var_idx, val in evidence.items()):
22-
weights.append(0.0)
23-
continue
24-
weight = 1.0
25-
for factor in model.factors:
26-
idx = tuple(assignment[v - 1] for v in factor.vars)
27-
weight *= float(factor.values[idx])
28-
weights.append(weight)
29-
total = sum(weights)
30-
marginals = {}
31-
for var_idx, card in enumerate(model.cards):
32-
values = []
33-
for value in range(card):
34-
mass = 0.0
35-
for assignment, weight in zip(assignments, weights):
36-
if assignment[var_idx] == value:
37-
mass += weight
38-
values.append(mass / total if total > 0 else 0.0)
39-
marginals[var_idx + 1] = torch.tensor(values, dtype=torch.float64)
40-
return marginals
22+
from tests.test_utils import exact_marginals
4123

4224

4325
class TestBPAdditionalCases(unittest.TestCase):

0 commit comments

Comments
 (0)