From 00749def587367bb75fe8a06df9f5684174e68ec Mon Sep 17 00:00:00 2001 From: XianghuWang-287 Date: Fri, 27 Feb 2026 13:38:58 -0800 Subject: [PATCH] Strengthen Case 2 unit test: verify full probability distribution Replace test_top_prediction_on_tail with test_tail_high_probability. Now checks tail (atoms 0-9) high probability AND non-tail low probability. Before bugfix (d858ff2): tail_sum=0.3755, FAILS (< 0.5 threshold) After bugfix (5918489): tail_sum=0.7433, PASSES --- tests/test_real_cases.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/tests/test_real_cases.py b/tests/test_real_cases.py index 0407b17..47d0d9e 100644 --- a/tests/test_real_cases.py +++ b/tests/test_real_cases.py @@ -136,17 +136,29 @@ def test_generate_probabilities(self): self.assertAlmostEqual(np.sum(scores), 1.0, places=5) self.assertTrue(np.all(scores >= 0)) - def test_top_prediction_on_tail(self): + def test_tail_high_probability(self): """The modification is on the alkyl tail (atoms 0-9). - The top predicted atom should be in that region.""" + Tail atoms should have high probability and non-tail atoms should have low probability.""" scores = self.mf.generate_probabilities() - predicted_site = int(np.argmax(scores)) tail_atoms = list(range(0, 10)) - self.assertIn( - predicted_site, - tail_atoms, - f"Predicted site {predicted_site} is not on the alkyl tail (atoms 0-9)", - ) + non_tail_atoms = list(range(10, len(scores))) + + tail_sum = sum(scores[i] for i in tail_atoms) + tail_mean = np.mean([scores[i] for i in tail_atoms]) + non_tail_mean = np.mean([scores[i] for i in non_tail_atoms]) + + # Tail should hold the majority of probability mass + self.assertGreater(tail_sum, 0.5, + f"Tail probability sum {tail_sum:.4f} should be > 0.5") + + # Tail mean should be significantly higher than non-tail mean + self.assertGreater(tail_mean, non_tail_mean * 3, + f"Tail mean {tail_mean:.4f} should be much higher than non-tail mean {non_tail_mean:.4f}") + + # Every non-tail atom should have low probability + for i in non_tail_atoms: + self.assertLess(scores[i], 0.03, + f"Non-tail atom {i} has unexpectedly high probability {scores[i]:.4f}") def test_get_edge_detail_no_side_effect(self): """Previously, calling get_edge_detail should not mutate the original edge matches.