Skip to content

Commit b43c95d

Browse files
committed
Add state_dict/load_state_dict to Algo classes and fix CI issues
- Add state_dict() and load_state_dict() methods to Algo base class - Override in BundleAlgo with serializable attributes - Update algo_to_json/algo_from_json to use state_dict pattern - Fix Black formatting in utils.py - Fix Windows path test for cross-platform compatibility Addresses reviewer feedback from @ericspod to follow PyTorch conventions.
1 parent 7f40e55 commit b43c95d

File tree

4 files changed

+61
-17
lines changed

4 files changed

+61
-17
lines changed

monai/apps/auto3dseg/bundle_gen.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,37 @@ def get_output_path(self):
367367
"""Returns the algo output paths to find the algo scripts and configs."""
368368
return self.output_path
369369

370+
def state_dict(self) -> dict:
371+
"""
372+
Return state for serialization.
373+
374+
Returns:
375+
A dictionary containing the BundleAlgo state to serialize.
376+
"""
377+
return {
378+
"template_path": self.template_path,
379+
"data_stats_files": self.data_stats_files,
380+
"data_list_file": self.data_list_file,
381+
"mlflow_tracking_uri": self.mlflow_tracking_uri,
382+
"mlflow_experiment_name": self.mlflow_experiment_name,
383+
"output_path": self.output_path,
384+
"name": self.name,
385+
"best_metric": self.best_metric,
386+
"fill_records": self.fill_records,
387+
"device_setting": self.device_setting,
388+
}
389+
390+
def load_state_dict(self, state: dict) -> None:
391+
"""
392+
Restore state from a dictionary.
393+
394+
Args:
395+
state: A dictionary containing the state to restore.
396+
"""
397+
for key, value in state.items():
398+
if hasattr(self, key):
399+
setattr(self, key, value)
400+
370401

371402
# path to download the algo_templates
372403
default_algo_zip = (

monai/auto3dseg/algo_gen.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,31 @@ def get_output_path(self, *args, **kwargs):
4343
"""Returns the algo output paths for scripts location"""
4444
pass
4545

46+
def state_dict(self) -> dict:
47+
"""
48+
Return state for serialization.
49+
50+
Subclasses should override this method to return a dictionary of
51+
attributes that need to be serialized. This follows the PyTorch
52+
convention for state management.
53+
54+
Returns:
55+
A dictionary containing the state to serialize.
56+
"""
57+
return {}
58+
59+
def load_state_dict(self, state: dict) -> None:
60+
"""
61+
Restore state from a dictionary.
62+
63+
Subclasses should override this method to restore their state
64+
from the dictionary returned by state_dict().
65+
66+
Args:
67+
state: A dictionary containing the state to restore.
68+
"""
69+
pass
70+
4671

4772
class AlgoGen(Randomizable):
4873
"""

monai/auto3dseg/utils.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -322,19 +322,7 @@ def algo_to_json(algo: Algo, template_path: PathLike | None = None, **algo_meta_
322322
Returns:
323323
Filename of the saved Algo object (algo_object.json).
324324
"""
325-
attrs = [
326-
"template_path",
327-
"data_stats_files",
328-
"data_list_file",
329-
"mlflow_tracking_uri",
330-
"mlflow_experiment_name",
331-
"output_path",
332-
"name",
333-
"best_metric",
334-
"fill_records",
335-
"device_setting"
336-
]
337-
state = {a : _make_json_serializable(getattr(algo, a)) for a in attrs if hasattr(algo, a)}
325+
state = {k: _make_json_serializable(v) for k, v in algo.state_dict().items()}
338326

339327
# Build target string for dynamic class instantiation
340328
cls = algo.__class__
@@ -503,9 +491,8 @@ def algo_from_json(filename: str, template_path: PathLike | None = None, **kwarg
503491
raise ValueError(f"Failed to instantiate Algo from target '{target}' with paths {template_paths}")
504492

505493
# Restore the state (skip template_path as it's set to the working import path below)
506-
for attr, value in state.items():
507-
if attr != "template_path" and hasattr(algo, attr):
508-
setattr(algo, attr, value)
494+
state_to_load = {k: v for k, v in state.items() if k != "template_path"}
495+
algo.load_state_dict(state_to_load)
509496

510497
# Use the path that successfully imported the class, not the original saved path
511498
# (the original path may no longer exist if the workdir was moved)

tests/auto3dseg/test_json_serialization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ def test_torch_tensor(self) -> None:
4848

4949
def test_path(self) -> None:
5050
p = Path("/some/path")
51-
assert _make_json_serializable(p) == "/some/path"
51+
# Use str(p) since path separators differ on Windows vs Unix
52+
assert _make_json_serializable(p) == str(p)
5253

5354
def test_fallback(self) -> None:
5455
class Custom:

0 commit comments

Comments
 (0)