Skip to content

Commit e5a1876

Browse files
authored
1132-fix-the-bug-in-loading-tool-segmentation-model (#1135)
Signed-off-by: binliu <binliu@nvidia.com> Fixes #1132 . ### Description The pretrained `endoscopic_tool_segmentation` model in MONAI model zoo cannot be loaded to `video_seg` tutorial since the different output. Therefore, the whole model except segmentation head should be loaded to perform pretrain. ### Checks <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [ ] Notebook runs automatically `./runner [-p <regex_pattern>]` Signed-off-by: binliu <binliu@nvidia.com>
1 parent cbeb69c commit e5a1876

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

computer_assisted_intervention/video_seg.ipynb

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -372,12 +372,11 @@
372372
"## Normal training loop\n",
373373
"\n",
374374
"A normal training loop will be excuted. Meanwhile, the iou value of each epoch will be recorded as the accuracy. \n",
375-
"There is a pretrained bundle for surgical tool segmentation task in the MONAI model-zoo, the model can be automatically downloaded and loaded to the network by adding codes to the block below:\n",
375+
"There is a pretrained bundle for surgical tool segmentation task in the MONAI model-zoo, the model can be automatically downloaded and loaded to the network by adding codes to the block below. Please add `from monai.networks.utils import copy_model_state` in the `Imports` part of this tutorial, when using it.\n",
376376
"```\n",
377-
"pretrained_weights = monai.bundle.load(\n",
378-
" name=\"endoscopic_tool_segmentation\", bundle_dir=\"./\", version=\"0.2.0\"\n",
379-
")\n",
380-
"model.load_state_dict(pretrained_weights)\n",
377+
"pretrained_weights = monai.bundle.load(name=\"endoscopic_tool_segmentation\", bundle_dir=\"./\", version=\"0.2.0\")\n",
378+
"new_model_dict, _, _ = copy_model_state(model, pretrained_weights, exclude_vars=\"segmentation_head\")\n",
379+
"model.load_state_dict(new_model_dict)\n",
381380
"```"
382381
]
383382
},
@@ -691,9 +690,9 @@
691690
],
692691
"metadata": {
693692
"kernelspec": {
694-
"display_name": "conda_base",
693+
"display_name": "Python 3 (ipykernel)",
695694
"language": "python",
696-
"name": "base"
695+
"name": "python3"
697696
},
698697
"language_info": {
699698
"codemirror_mode": {

0 commit comments

Comments
 (0)