Skip to content

Commit 9edac23

Browse files
committed
Address remaining PR review comments: add docstring entries for validation errors and extend tests
- Update SoftclDiceLoss docstring to document TypeError for iter_ type check and ValueError for negative iter_ - Update SoftDiceclDiceLoss docstring to document ValueError for alpha range validation - Add test_invalid_iter_type to verify TypeError when iter_ is not an int - Add test_invalid_iter_value to verify ValueError when iter_ is negative - Add test_invalid_alpha and test_invalid_alpha_negative to verify alpha validation Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent 976a307 commit 9edac23

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

monai/losses/cldice.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ def __init__(
152152
153153
Raises:
154154
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
155+
TypeError: When ``iter_`` is not an ``int``.
156+
ValueError: When ``iter_`` is a negative integer.
155157
ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
156158
Incompatible values.
157159
@@ -296,6 +298,7 @@ def __init__(
296298
297299
Raises:
298300
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
301+
ValueError: When ``alpha`` is not in ``[0, 1]``.
299302
ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
300303
Incompatible values.
301304

tests/losses/test_cldice_loss.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,14 @@ def test_invalid_other_act(self):
106106
with self.assertRaises(TypeError):
107107
SoftclDiceLoss(other_act="invalid")
108108

109+
def test_invalid_iter_type(self):
110+
with self.assertRaises(TypeError):
111+
SoftclDiceLoss(iter_=3.0)
112+
113+
def test_invalid_iter_value(self):
114+
with self.assertRaises(ValueError):
115+
SoftclDiceLoss(iter_=-1)
116+
109117

110118
class TestSoftDiceclDiceLoss(unittest.TestCase):
111119
@parameterized.expand(COMBINED_CASES)
@@ -131,6 +139,14 @@ def test_channel_mismatch(self):
131139
with self.assertRaises(ValueError):
132140
loss(torch.ones(2, 3, 8, 8), torch.ones(2, 2, 8, 8))
133141

142+
def test_invalid_alpha(self):
143+
with self.assertRaises(ValueError):
144+
SoftDiceclDiceLoss(alpha=1.5)
145+
146+
def test_invalid_alpha_negative(self):
147+
with self.assertRaises(ValueError):
148+
SoftDiceclDiceLoss(alpha=-0.5)
149+
134150

135151
if __name__ == "__main__":
136152
unittest.main()

0 commit comments

Comments
 (0)