Skip to content

Commit 0a0e064

Browse files
committed
Simplify state_dict and fix Windows CI
- Remove template_path from state_dict (determined at load time) - Remove hasattr check in load_state_dict - Remove redundant test_path test Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent d6a7825 commit 0a0e064

File tree

3 files changed

+7
-12
lines changed

3 files changed

+7
-12
lines changed

monai/apps/auto3dseg/bundle_gen.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -373,9 +373,12 @@ def state_dict(self) -> dict:
373373
374374
Returns:
375375
A dictionary containing the BundleAlgo state to serialize.
376+
377+
Note:
378+
template_path is excluded as it is determined dynamically at load time
379+
based on which path successfully imports the Algo class.
376380
"""
377381
return {
378-
"template_path": self.template_path,
379382
"data_stats_files": self.data_stats_files,
380383
"data_list_file": self.data_list_file,
381384
"mlflow_tracking_uri": self.mlflow_tracking_uri,
@@ -395,8 +398,7 @@ def load_state_dict(self, state: dict) -> None:
395398
state: A dictionary containing the state to restore.
396399
"""
397400
for key, value in state.items():
398-
if hasattr(self, key):
399-
setattr(self, key, value)
401+
setattr(self, key, value)
400402

401403

402404
# path to download the algo_templates

monai/auto3dseg/utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -490,9 +490,8 @@ def algo_from_json(filename: str, template_path: PathLike | None = None, **kwarg
490490
if algo is None:
491491
raise ValueError(f"Failed to instantiate Algo from target '{target}' with paths {template_paths}")
492492

493-
# Restore the state (skip template_path as it's set to the working import path below)
494-
state_to_load = {k: v for k, v in state.items() if k != "template_path"}
495-
algo.load_state_dict(state_to_load)
493+
# Restore the state
494+
algo.load_state_dict(state)
496495

497496
# Use the path that successfully imported the class, not the original saved path
498497
# (the original path may no longer exist if the workdir was moved)

tests/auto3dseg/test_json_serialization.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import os
1515
import tempfile
1616
import unittest
17-
from pathlib import Path
1817

1918
import numpy as np
2019
import torch
@@ -46,11 +45,6 @@ def test_torch_tensor(self) -> None:
4645
result = _make_json_serializable(t)
4746
assert result == [1.0, 2.0]
4847

49-
def test_path(self) -> None:
50-
p = Path("/some/path")
51-
# Use str(p) since path separators differ on Windows vs Unix
52-
assert _make_json_serializable(p) == str(p)
53-
5448
def test_fallback(self) -> None:
5549
class Custom:
5650
def __str__(self) -> str:

0 commit comments

Comments
 (0)