Skip to content

Commit 6b3f4e2

Browse files
Add tests for BP edge cases and UAI errors
1 parent fc89e7f commit 6b3f4e2

2 files changed

Lines changed: 80 additions & 0 deletions

File tree

tests/test_uai_parser.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,46 @@ def test_read_evidence_file(self):
5858
evidence = read_evidence_file("examples/simple_model.evid")
5959
self.assertEqual(evidence, {1: 1})
6060

61+
def test_invalid_network_type(self):
62+
content = "\n".join(
63+
[
64+
"INVALID",
65+
"1",
66+
"2",
67+
"0",
68+
]
69+
)
70+
with self.assertRaises(ValueError):
71+
read_model_from_string(content)
72+
73+
def test_scope_size_mismatch(self):
74+
content = "\n".join(
75+
[
76+
"MARKOV",
77+
"2",
78+
"2 2",
79+
"1",
80+
"2 0",
81+
"2",
82+
"0.5 0.5",
83+
]
84+
)
85+
with self.assertRaises(ValueError):
86+
read_model_from_string(content)
87+
88+
def test_missing_table_entries(self):
89+
content = "\n".join(
90+
[
91+
"MARKOV",
92+
"1",
93+
"2",
94+
"1",
95+
"1 0",
96+
]
97+
)
98+
with self.assertRaises(ValueError):
99+
read_model_from_string(content)
100+
61101

62102
if __name__ == "__main__":
63103
unittest.main()

tests/testcase.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
initial_state,
1717
collect_message,
1818
process_message,
19+
apply_evidence,
1920
)
2021

2122

@@ -91,6 +92,45 @@ def test_message_normalization(self):
9192
for msg in var_msgs:
9293
self.assertAlmostEqual(float(msg.sum()), 1.0, places=6)
9394

95+
def test_zero_message_handling(self):
96+
content = "\n".join(
97+
[
98+
"MARKOV",
99+
"1",
100+
"2",
101+
"2",
102+
"1 0",
103+
"1 0",
104+
"2",
105+
"0.0 0.0",
106+
"2",
107+
"0.7 0.3",
108+
]
109+
)
110+
model = read_model_from_string(content)
111+
bp = BeliefPropagation(model)
112+
state = initial_state(bp)
113+
collect_message(bp, state, normalize=True)
114+
process_message(bp, state, normalize=True, damping=0.0)
115+
self.assertAlmostEqual(float(state.message_in[0][0].sum()), 0.0, places=6)
116+
self.assertAlmostEqual(float(state.message_out[0][1].sum()), 0.0, places=6)
117+
118+
def test_evidence_out_of_range_zeros_factor(self):
119+
content = "\n".join(
120+
[
121+
"MARKOV",
122+
"1",
123+
"2",
124+
"1",
125+
"1 0",
126+
"2",
127+
"0.4 0.6",
128+
]
129+
)
130+
model = read_model_from_string(content)
131+
bp = apply_evidence(BeliefPropagation(model), {1: 5})
132+
self.assertAlmostEqual(float(bp.factors[0].values.sum()), 0.0, places=6)
133+
94134

95135
if __name__ == "__main__":
96136
unittest.main()

0 commit comments

Comments
 (0)