From 588b299f5b2042b2739256ab35e7de8cbb854e1e Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Fri, 14 Mar 2025 09:16:50 -0400 Subject: [PATCH 1/4] Update readme Signed-off-by: heyufan1995 --- vista3d/cvpr_workshop/README.md | 2 -- vista3d/cvpr_workshop/train_cvpr.py | 5 +++++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/vista3d/cvpr_workshop/README.md b/vista3d/cvpr_workshop/README.md index 6769fb9..fc9e675 100644 --- a/vista3d/cvpr_workshop/README.md +++ b/vista3d/cvpr_workshop/README.md @@ -15,7 +15,6 @@ limitations under the License. This repository is written for the "CVPR 2025: Foundation Models for Interactive 3D Biomedical Image Segmentation"([link](https://www.codabench.org/competitions/5263/)) challenge. It is based on MONAI 1.4. Many of the functions in the main VISTA3D repository are moved to MONAI 1.4 and this simplified folder will directly use components from MONAI. - It is simplified to train interactive segmentation models across different modalities. The sophisticated transforms and recipes used for VISTA3D are removed. The finetuned VISTA3D checkpoint on the challenge subsets is available [here](https://drive.google.com/file/d/1r2KvHP_30nHR3LU7NJEdscVnlZ2hTtcd/view?usp=sharing) # Setup @@ -38,4 +37,3 @@ docker save -o vista3d.tar.gz vista3d:latest ``` - diff --git a/vista3d/cvpr_workshop/train_cvpr.py b/vista3d/cvpr_workshop/train_cvpr.py index 25a33bc..b1973a6 100755 --- a/vista3d/cvpr_workshop/train_cvpr.py +++ b/vista3d/cvpr_workshop/train_cvpr.py @@ -104,12 +104,15 @@ def __getitem__(self, idx): return data # Training function def train(): + json_file = "subset.json" # Update with your JSON file json_file = "subset.json" # Update with your JSON file epoch_number = 100 start_epoch = 0 lr = 2e-5 checkpoint_dir = "checkpoints" start_checkpoint = '/workspace/CPRR25_vista3D_model_final_10percent_data.pth' + start_checkpoint = '/workspace/CPRR25_vista3D_model_final_10percent_data.pth' + os.makedirs(checkpoint_dir, exist_ok=True) dist.init_process_group(backend="nccl") world_size = int(os.environ["WORLD_SIZE"]) @@ -122,6 +125,8 @@ def train(): model = vista3d132(in_channels=1).to(device) pretrained_ckpt = torch.load(start_checkpoint, map_location=device) # pretrained_ckpt = torch.load(os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth")) + pretrained_ckpt = torch.load(start_checkpoint, map_location=device) + # pretrained_ckpt = torch.load(os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth")) model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) model.load_state_dict(pretrained_ckpt['model'], strict=True) optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1.0e-05) From 95187baa0627adf8d39ed192d099417903324a13 Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Tue, 18 Mar 2025 12:28:02 -0400 Subject: [PATCH 2/4] Fix finetuned checkpoint path Signed-off-by: heyufan1995 --- vista3d/cvpr_workshop/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vista3d/cvpr_workshop/README.md b/vista3d/cvpr_workshop/README.md index fc9e675..4a61a6c 100644 --- a/vista3d/cvpr_workshop/README.md +++ b/vista3d/cvpr_workshop/README.md @@ -23,7 +23,7 @@ pip install -r requirements.txt ``` # Training -Download the challenge subsets finetuned [checkpoint](https://drive.google.com/file/d/1r2KvHP_30nHR3LU7NJEdscVnlZ2hTtcd/view?usp=sharing) or VISTA3D original [checkpoint]((https://drive.google.com/file/d/1DRYA2-AI-UJ23W1VbjqHsnHENGi0ShUl/view?usp=sharing)). 
Generate a json list that contains your traning data and update the json file path in the script. +Download the challenge subsets finetuned [checkpoint](https://drive.google.com/file/d/1hQ8imaf4nNSg_43dYbPSJT0dr7JgAKWX/view?usp=sharing) or VISTA3D original [checkpoint]((https://drive.google.com/file/d/1DRYA2-AI-UJ23W1VbjqHsnHENGi0ShUl/view?usp=sharing)). Generate a json list that contains your traning data and update the json file path in the script. ``` torchrun --nnodes=1 --nproc_per_node=8 train_cvpr.py ``` @@ -35,5 +35,6 @@ We provide a Dockerfile to satisfy the challenge format. For more details, refer docker build -t vista3d:latest . docker save -o vista3d.tar.gz vista3d:latest ``` +You can also directly run `predict.sh`. Download the finetuned checkpoint and modify the `--model=/your_downloaded_checkpoint'. From 8427de47ea5eb84945fd34c422d731f18b6cfb8b Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Tue, 18 Mar 2025 12:28:18 -0400 Subject: [PATCH 3/4] Update readme Signed-off-by: heyufan1995 --- vista3d/cvpr_workshop/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vista3d/cvpr_workshop/README.md b/vista3d/cvpr_workshop/README.md index 4a61a6c..626f184 100644 --- a/vista3d/cvpr_workshop/README.md +++ b/vista3d/cvpr_workshop/README.md @@ -27,6 +27,8 @@ Download the challenge subsets finetuned [checkpoint](https://drive.google.com/f ``` torchrun --nnodes=1 --nproc_per_node=8 train_cvpr.py ``` +The checkpoint saved by train_cvpr.py can be updated by `update_ckpt.py` to remove the additional `module` key due to multi-gpu training. + # Inference You can directly download the [docker file](https://drive.google.com/file/d/1r2KvHP_30nHR3LU7NJEdscVnlZ2hTtcd/view?usp=sharing) for the challenge baseline. @@ -35,6 +37,6 @@ We provide a Dockerfile to satisfy the challenge format. For more details, refer docker build -t vista3d:latest . docker save -o vista3d.tar.gz vista3d:latest ``` -You can also directly run `predict.sh`. Download the finetuned checkpoint and modify the `--model=/your_downloaded_checkpoint'. +You can also directly run `predict.sh`. Download the finetuned checkpoint and modify the `--model=/your_downloaded_checkpoint`. Change `save_data=True` in `infer_cvpr.py` to save predictions to nifti files for visualization. 
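The `subset.json` read by `NPZDataset` in `train_cvpr.py` is a plain JSON list of `.npz` file names that get joined with the dataset's hard-coded `base_path`. A minimal sketch for generating it, assuming the challenge training volumes already sit under that directory (adjust both paths to your own layout):

```
import glob
import json
import os

# NPZDataset joins each entry with this base_path, so store paths relative to it.
base_path = "/workspace/VISTA/CVPR-MedSegFMCompetition/trainsubset"

file_paths = sorted(
    os.path.relpath(p, base_path)
    for p in glob.glob(os.path.join(base_path, "**", "*.npz"), recursive=True)
)

with open("subset.json", "w") as f:
    json.dump(file_paths, f, indent=2)
```

Each listed file is loaded with `np.load` and is expected to carry the `imgs`, `gts`, and `spacing` arrays used in `NPZDataset.__getitem__`.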
From f2c4fe1436890e5ffd2e6b41f5bc90d3166573c9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Mar 2025 16:35:54 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- vista3d/cvpr_workshop/Dockerfile | 2 +- vista3d/cvpr_workshop/README.md | 8 +- vista3d/cvpr_workshop/infer_cvpr.py | 166 ++++++++++++----------- vista3d/cvpr_workshop/requirements.txt | 2 +- vista3d/cvpr_workshop/train_cvpr.py | 179 ++++++++++++++++--------- vista3d/cvpr_workshop/update_ckpt.py | 23 +++- 6 files changed, 224 insertions(+), 156 deletions(-) diff --git a/vista3d/cvpr_workshop/Dockerfile b/vista3d/cvpr_workshop/Dockerfile index 186d69f..418ef7d 100755 --- a/vista3d/cvpr_workshop/Dockerfile +++ b/vista3d/cvpr_workshop/Dockerfile @@ -21,4 +21,4 @@ COPY predict.sh /workspace/predict.sh RUN chmod +x /workspace/predict.sh # Set default command -CMD ["/bin/bash"] \ No newline at end of file +CMD ["/bin/bash"] diff --git a/vista3d/cvpr_workshop/README.md b/vista3d/cvpr_workshop/README.md index 626f184..a38091e 100644 --- a/vista3d/cvpr_workshop/README.md +++ b/vista3d/cvpr_workshop/README.md @@ -12,7 +12,7 @@ limitations under the License. --> # Overview -This repository is written for the "CVPR 2025: Foundation Models for Interactive 3D Biomedical Image Segmentation"([link](https://www.codabench.org/competitions/5263/)) challenge. It +This repository is written for the "CVPR 2025: Foundation Models for Interactive 3D Biomedical Image Segmentation"([link](https://www.codabench.org/competitions/5263/)) challenge. It is based on MONAI 1.4. Many of the functions in the main VISTA3D repository are moved to MONAI 1.4 and this simplified folder will directly use components from MONAI. It is simplified to train interactive segmentation models across different modalities. The sophisticated transforms and recipes used for VISTA3D are removed. The finetuned VISTA3D checkpoint on the challenge subsets is available [here](https://drive.google.com/file/d/1r2KvHP_30nHR3LU7NJEdscVnlZ2hTtcd/view?usp=sharing) @@ -27,7 +27,7 @@ Download the challenge subsets finetuned [checkpoint](https://drive.google.com/f ``` torchrun --nnodes=1 --nproc_per_node=8 train_cvpr.py ``` -The checkpoint saved by train_cvpr.py can be updated by `update_ckpt.py` to remove the additional `module` key due to multi-gpu training. +The checkpoint saved by train_cvpr.py can be updated by `update_ckpt.py` to remove the additional `module` key due to multi-gpu training. # Inference @@ -37,6 +37,4 @@ We provide a Dockerfile to satisfy the challenge format. For more details, refer docker build -t vista3d:latest . docker save -o vista3d.tar.gz vista3d:latest ``` -You can also directly run `predict.sh`. Download the finetuned checkpoint and modify the `--model=/your_downloaded_checkpoint`. Change `save_data=True` in `infer_cvpr.py` to save predictions to nifti files for visualization. - - +You can also directly run `predict.sh`. Download the finetuned checkpoint and modify the `--model=/your_downloaded_checkpoint`. Change `save_data=True` in `infer_cvpr.py` to save predictions to nifti files for visualization. 
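`infer_cvpr.py` below expects each case under `--test_img_path` to be an `.npz` file holding an `imgs` volume, a `spacing` vector, and optionally a `clicks` list with per-object `fg`/`bg` points; predictions are written to `--save_path` as a `segs` array. A small sketch that packages a toy case in this layout (the volume, spacing, and click coordinates are placeholders, not challenge data):

```
import os

import numpy as np

os.makedirs("tests", exist_ok=True)

# Placeholder 3D volume and voxel spacing.
imgs = np.random.rand(128, 128, 128).astype(np.float32)
spacing = np.array([1.0, 1.0, 1.0])

# One object with a single foreground click (three voxel indices) and no
# background clicks, matching the {"fg": [...], "bg": [...]} structure that
# convert_clicks() consumes.
clicks = np.array([{"fg": [[64, 64, 64]], "bg": []}], dtype=object)

np.savez_compressed("tests/toy_case.npz", imgs=imgs, spacing=spacing, clicks=clicks)
```

With the defaults from the argument parser, `python infer_cvpr.py --test_img_path ./tests --save_path ./outputs --model checkpoints/model_final.pth` then writes `./outputs/toy_case.npz` containing the predicted `segs`.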
diff --git a/vista3d/cvpr_workshop/infer_cvpr.py b/vista3d/cvpr_workshop/infer_cvpr.py index 7694965..dddf7f2 100755 --- a/vista3d/cvpr_workshop/infer_cvpr.py +++ b/vista3d/cvpr_workshop/infer_cvpr.py @@ -1,147 +1,159 @@ +import argparse +import glob + import monai import monai.transforms -import torch -import argparse -import numpy as np import nibabel as nib -import glob -from monai.networks.nets.vista3d import vista3d132 -from monai.utils import optional_import +import numpy as np +import torch from monai.apps.vista3d.inferer import point_based_window_inferer from monai.inferers import SlidingWindowInfererAdapt +from monai.networks.nets.vista3d import vista3d132 +from monai.utils import optional_import tqdm, _ = optional_import("tqdm", name="tqdm") -import numpy as np -import pdb import os + + def convert_clicks(alldata): # indexes = list(alldata.keys()) # data = [alldata[i] for i in indexes] data = alldata B = len(data) # Number of objects - indexes = np.arange(1, B+1).tolist() + indexes = np.arange(1, B + 1).tolist() # Determine the maximum number of points across all objects - max_N = max(len(obj['fg']) + len(obj['bg']) for obj in data) - + max_N = max(len(obj["fg"]) + len(obj["bg"]) for obj in data) + # Initialize padded arrays point_coords = np.zeros((B, max_N, 3), dtype=int) point_labels = np.full((B, max_N), -1, dtype=int) - + for i, obj in enumerate(data): points = [] labels = [] - + # Add foreground points - for fg_point in obj['fg']: + for fg_point in obj["fg"]: points.append(fg_point) labels.append(1) - + # Add background points - for bg_point in obj['bg']: + for bg_point in obj["bg"]: points.append(bg_point) labels.append(0) - + # Fill in the arrays - point_coords[i, :len(points)] = points - point_labels[i, :len(labels)] = labels - + point_coords[i, : len(points)] = points + point_labels[i, : len(labels)] = labels + return point_coords, point_labels, indexes -if __name__ == '__main__': +if __name__ == "__main__": # set to true to save nifti files for visualization save_data = False - point_inferer = True # use point based inferen - roi_size = [128,128,128] + point_inferer = True # use point based inferen + roi_size = [128, 128, 128] parser = argparse.ArgumentParser() - parser.add_argument("--test_img_path", type=str, default='./tests') - parser.add_argument("--save_path", type=str, default='./outputs/') - parser.add_argument("--model", type=str, default='checkpoints/model_final.pth') + parser.add_argument("--test_img_path", type=str, default="./tests") + parser.add_argument("--save_path", type=str, default="./outputs/") + parser.add_argument("--model", type=str, default="checkpoints/model_final.pth") args = parser.parse_args() - os.makedirs(args.save_path,exist_ok=True) + os.makedirs(args.save_path, exist_ok=True) # load model checkpoint_path = args.model model = vista3d132(in_channels=1) - pretrained_ckpt = torch.load(checkpoint_path, map_location='cuda') + pretrained_ckpt = torch.load(checkpoint_path, map_location="cuda") model.load_state_dict(pretrained_ckpt, strict=True) - + # load data test_cases = glob.glob(os.path.join(args.test_img_path, "*.npz")) for img_path in test_cases: case_name = os.path.basename(img_path) print(case_name) img = np.load(img_path, allow_pickle=True) - img_array = img['imgs'] - spacing = img['spacing'] + img_array = img["imgs"] + spacing = img["spacing"] original_shape = img_array.shape affine = np.diag(spacing.tolist() + [1]) # 4x4 affine matrix if save_data: # Create a NIfTI image nifti_img = nib.Nifti1Image(img_array, affine) # Save the NIfTI 
file - nib.save(nifti_img, img_path.replace('.npz','.nii.gz')) - nifti_img = nib.Nifti1Image(img['gts'], affine) + nib.save(nifti_img, img_path.replace(".npz", ".nii.gz")) + nifti_img = nib.Nifti1Image(img["gts"], affine) # Save the NIfTI file - nib.save(nifti_img, img_path.replace('.npz','gts.nii.gz')) - clicks = img.get('clicks', [{'fg':[[418, 138, 136]], 'bg':[]}]) + nib.save(nifti_img, img_path.replace(".npz", "gts.nii.gz")) + clicks = img.get("clicks", [{"fg": [[418, 138, 136]], "bg": []}]) point_coords, point_labels, indexes = convert_clicks(clicks) # preprocess img_array = torch.from_numpy(img_array) - img_array = img_array.unsqueeze(0) - img_array = monai.transforms.ScaleIntensityRangePercentiles(lower=1, upper=99, b_min=0, b_max=1, clip=True)(img_array) - img_array = img_array.unsqueeze(0) # add channel dim - device = 'cuda' + img_array = img_array.unsqueeze(0) + img_array = monai.transforms.ScaleIntensityRangePercentiles( + lower=1, upper=99, b_min=0, b_max=1, clip=True + )(img_array) + img_array = img_array.unsqueeze(0) # add channel dim + device = "cuda" # slidingwindow with torch.no_grad(): if not point_inferer: - model.NINF_VALUE = 0 # set to 0 in case sliding window is used. + model.NINF_VALUE = 0 # set to 0 in case sliding window is used. # directly using slidingwindow inferer is not optimal. - val_outputs = SlidingWindowInfererAdapt( - roi_size=roi_size, sw_batch_size=1, with_coord=True, padding_mode="replicate" - )( - inputs=img_array.to(device), - transpose=True, - network=model.to(device), - point_coords=torch.from_numpy(point_coords).to(device), - point_labels=torch.from_numpy(point_labels).to(device) - )[0] > 0 + val_outputs = ( + SlidingWindowInfererAdapt( + roi_size=roi_size, + sw_batch_size=1, + with_coord=True, + padding_mode="replicate", + )( + inputs=img_array.to(device), + transpose=True, + network=model.to(device), + point_coords=torch.from_numpy(point_coords).to(device), + point_labels=torch.from_numpy(point_labels).to(device), + )[ + 0 + ] + > 0 + ) final_outputs = torch.zeros_like(val_outputs[0], dtype=torch.float32) for i, v in enumerate(val_outputs): final_outputs += indexes[i] * v else: # point based - final_outputs = torch.zeros_like(img_array[0,0], dtype=torch.float32) + final_outputs = torch.zeros_like(img_array[0, 0], dtype=torch.float32) for i, v in enumerate(indexes): - val_outputs = point_based_window_inferer( - inputs=img_array.to(device), - roi_size=roi_size, - transpose=True, - with_coord=True, - predictor=model.to(device), - mode="gaussian", - sw_device=device, - device=device, - center_only=True, # only crop the center - point_coords=torch.from_numpy(point_coords[[i]]).to(device), - point_labels=torch.from_numpy(point_labels[[i]]).to(device) - )[0] > 0 + val_outputs = ( + point_based_window_inferer( + inputs=img_array.to(device), + roi_size=roi_size, + transpose=True, + with_coord=True, + predictor=model.to(device), + mode="gaussian", + sw_device=device, + device=device, + center_only=True, # only crop the center + point_coords=torch.from_numpy(point_coords[[i]]).to(device), + point_labels=torch.from_numpy(point_labels[[i]]).to(device), + )[0] + > 0 + ) final_outputs[val_outputs[0]] = v final_outputs = torch.nan_to_num(final_outputs) - # save data + # save data if save_data: # Create a NIfTI image - nifti_img = nib.Nifti1Image(final_outputs.to(torch.float32).data.cpu().numpy(), affine) + nifti_img = nib.Nifti1Image( + final_outputs.to(torch.float32).data.cpu().numpy(), affine + ) # Save the NIfTI file - nib.save(nifti_img, 
os.path.join(args.save_path, case_name.replace('.npz','.nii.gz'))) - np.savez_compressed(os.path.join(args.save_path, case_name), segs=final_outputs.to(torch.float32).data.cpu().numpy()) - - - - - - - - - - \ No newline at end of file + nib.save( + nifti_img, + os.path.join(args.save_path, case_name.replace(".npz", ".nii.gz")), + ) + np.savez_compressed( + os.path.join(args.save_path, case_name), + segs=final_outputs.to(torch.float32).data.cpu().numpy(), + ) diff --git a/vista3d/cvpr_workshop/requirements.txt b/vista3d/cvpr_workshop/requirements.txt index c57fbde..308df32 100755 --- a/vista3d/cvpr_workshop/requirements.txt +++ b/vista3d/cvpr_workshop/requirements.txt @@ -10,4 +10,4 @@ numpy scipy cupy-cuda12x cucim -tqdm \ No newline at end of file +tqdm diff --git a/vista3d/cvpr_workshop/train_cvpr.py b/vista3d/cvpr_workshop/train_cvpr.py index b1973a6..c905e98 100755 --- a/vista3d/cvpr_workshop/train_cvpr.py +++ b/vista3d/cvpr_workshop/train_cvpr.py @@ -1,35 +1,34 @@ -import os import json +import os +import warnings + +import monai import monai.transforms +import numpy as np import torch -import torch.nn as nn -import torch.optim as optim import torch.distributed as dist -from torch.utils.data import Dataset -from torch.nn.parallel import DistributedDataParallel as DDP -import numpy as np -import monai -from tqdm import tqdm -import pdb -from monai.networks.nets import vista3d132 +import torch.optim as optim from monai.apps.vista3d.sampler import sample_prompt_pairs +from monai.data import DataLoader +from monai.networks.nets import vista3d132 +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import Dataset from torch.utils.tensorboard import SummaryWriter -from monai.data import DataLoader, DistributedSampler -import warnings -import nibabel as nib +from tqdm import tqdm + warnings.simplefilter("ignore") # Custom dataset for .npz files import matplotlib.pyplot as plt -import torchvision.utils as vutils -NUM_PATCHES_PER_IMAGE=4 +NUM_PATCHES_PER_IMAGE = 4 + def plot_to_tensorboard(writer, epoch, inputs, labels, points, outputs): """ Plots B figures, where each figure shows the slice where the point is located and overlays the point on this slice. 
- + Args: writer: TensorBoard writer epoch: Current epoch number @@ -44,64 +43,96 @@ def plot_to_tensorboard(writer, epoch, inputs, labels, points, outputs): for b in range(B): fig, axes = plt.subplots(1, 3, figsize=(12, 4)) - + # Select the first click point in (z, y, x) format x, y, z = points[b, 0].cpu().numpy().astype(int) - + # Extract the corresponding slice input_slice = inputs_np[:, :, z] # Get slice at depth z label_slice = labels_np[:, :, z] output_slice = outputs[b, 0].cpu().detach().numpy()[:, :, z] > 0 - + # Plot input with point overlay - axes[0].imshow(input_slice, cmap='gray') - axes[0].scatter(y, x, c='red', marker='x', s=50) + axes[0].imshow(input_slice, cmap="gray") + axes[0].scatter(y, x, c="red", marker="x", s=50) axes[0].set_title(f"Input (Slice {z})") - + # Plot label - axes[1].imshow(label_slice, cmap='gray') - axes[0].scatter(y, x, c='red', marker='x', s=50) + axes[1].imshow(label_slice, cmap="gray") + axes[0].scatter(y, x, c="red", marker="x", s=50) axes[1].set_title(f"Ground Truth (Slice {z})") - + # Plot output - axes[2].imshow(output_slice, cmap='gray') - axes[0].scatter(y, x, c='red', marker='x', s=50) + axes[2].imshow(output_slice, cmap="gray") + axes[0].scatter(y, x, c="red", marker="x", s=50) axes[2].set_title(f"Model Output (Slice {z})") - + plt.tight_layout() - + # Log figure to TensorBoard writer.add_figure(f"Object_{b}_Segmentation", fig, epoch) plt.close(fig) + class NPZDataset(Dataset): def __init__(self, json_file): - with open(json_file, 'r') as f: + with open(json_file, "r") as f: self.file_paths = json.load(f) - self.base_path = '/workspace/VISTA/CVPR-MedSegFMCompetition/trainsubset' + self.base_path = "/workspace/VISTA/CVPR-MedSegFMCompetition/trainsubset" + def __len__(self): return len(self.file_paths) def __getitem__(self, idx): img = np.load(os.path.join(self.base_path, self.file_paths[idx])) - img_array = torch.from_numpy(img['imgs']).unsqueeze(0).to(torch.float32) - label = torch.from_numpy(img['gts']).unsqueeze(0).to(torch.int32) - data = {"image": img_array, "label": label, 'filename': self.file_paths[idx]} - affine = np.diag(img['spacing'].tolist() + [1]) # 4x4 affine matrix - transforms = monai.transforms.Compose([ - monai.transforms.ScaleIntensityRangePercentilesd(keys="image", lower=1, upper=99, b_min=0, b_max=1, clip=True), - monai.transforms.SpatialPadd(mode=["constant", "constant"], keys=["image", "label"], spatial_size=[128, 128, 128]), - monai.transforms.RandCropByLabelClassesd(spatial_size=[128, 128, 128], keys=["image", "label"], label_key="label",num_classes=label.max() + 1, num_samples=NUM_PATCHES_PER_IMAGE), - monai.transforms.RandScaleIntensityd(factors=0.2, prob=0.2, keys="image"), - monai.transforms.RandShiftIntensityd(offsets=0.2, prob=0.2, keys="image"), - monai.transforms.RandGaussianNoised(mean=0., std=0.2, prob=0.2, keys="image"), - monai.transforms.RandFlipd(spatial_axis=0, prob=0.2, keys=["image", "label"]), - monai.transforms.RandFlipd(spatial_axis=1, prob=0.2, keys=["image", "label"]), - monai.transforms.RandFlipd(spatial_axis=2, prob=0.2, keys=["image", "label"]), - monai.transforms.RandRotate90d(max_k=3, prob=0.2, keys=["image", "label"]) - ]) + img_array = torch.from_numpy(img["imgs"]).unsqueeze(0).to(torch.float32) + label = torch.from_numpy(img["gts"]).unsqueeze(0).to(torch.int32) + data = {"image": img_array, "label": label, "filename": self.file_paths[idx]} + affine = np.diag(img["spacing"].tolist() + [1]) # 4x4 affine matrix + transforms = monai.transforms.Compose( + [ + 
monai.transforms.ScaleIntensityRangePercentilesd( + keys="image", lower=1, upper=99, b_min=0, b_max=1, clip=True + ), + monai.transforms.SpatialPadd( + mode=["constant", "constant"], + keys=["image", "label"], + spatial_size=[128, 128, 128], + ), + monai.transforms.RandCropByLabelClassesd( + spatial_size=[128, 128, 128], + keys=["image", "label"], + label_key="label", + num_classes=label.max() + 1, + num_samples=NUM_PATCHES_PER_IMAGE, + ), + monai.transforms.RandScaleIntensityd( + factors=0.2, prob=0.2, keys="image" + ), + monai.transforms.RandShiftIntensityd( + offsets=0.2, prob=0.2, keys="image" + ), + monai.transforms.RandGaussianNoised( + mean=0.0, std=0.2, prob=0.2, keys="image" + ), + monai.transforms.RandFlipd( + spatial_axis=0, prob=0.2, keys=["image", "label"] + ), + monai.transforms.RandFlipd( + spatial_axis=1, prob=0.2, keys=["image", "label"] + ), + monai.transforms.RandFlipd( + spatial_axis=2, prob=0.2, keys=["image", "label"] + ), + monai.transforms.RandRotate90d( + max_k=3, prob=0.2, keys=["image", "label"] + ), + ] + ) data = transforms(data) return data + + # Training function def train(): json_file = "subset.json" # Update with your JSON file @@ -110,8 +141,8 @@ def train(): start_epoch = 0 lr = 2e-5 checkpoint_dir = "checkpoints" - start_checkpoint = '/workspace/CPRR25_vista3D_model_final_10percent_data.pth' - start_checkpoint = '/workspace/CPRR25_vista3D_model_final_10percent_data.pth' + start_checkpoint = "/workspace/CPRR25_vista3D_model_final_10percent_data.pth" + start_checkpoint = "/workspace/CPRR25_vista3D_model_final_10percent_data.pth" os.makedirs(checkpoint_dir, exist_ok=True) dist.init_process_group(backend="nccl") @@ -120,7 +151,9 @@ def train(): torch.cuda.set_device(local_rank) device = torch.device(f"cuda:{local_rank}") dataset = NPZDataset(json_file) - sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=local_rank) + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, num_replicas=world_size, rank=local_rank + ) dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=32) model = vista3d132(in_channels=1).to(device) pretrained_ckpt = torch.load(start_checkpoint, map_location=device) @@ -128,9 +161,14 @@ def train(): pretrained_ckpt = torch.load(start_checkpoint, map_location=device) # pretrained_ckpt = torch.load(os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth")) model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) - model.load_state_dict(pretrained_ckpt['model'], strict=True) + model.load_state_dict(pretrained_ckpt["model"], strict=True) optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1.0e-05) - lr_scheduler = monai.optimizers.WarmupCosineSchedule(optimizer=optimizer, t_total= epoch_number+1, warmup_multiplier=0.1, warmup_steps=0) + lr_scheduler = monai.optimizers.WarmupCosineSchedule( + optimizer=optimizer, + t_total=epoch_number + 1, + warmup_multiplier=0.1, + warmup_steps=0, + ) if local_rank == 0: writer = SummaryWriter(log_dir=os.path.join(checkpoint_dir, "Events")) @@ -151,10 +189,12 @@ def train(): max_prompt=10, drop_label_prob=1, drop_point_prob=0, - ) + ) skip_update = torch.zeros(1, device=device) if point is None: - print(f"Iteration skipped due to None prompts at {batch['filename']}") + print( + f"Iteration skipped due to None prompts at {batch['filename']}" + ) skip_update = torch.ones(1, device=device) if world_size > 1: dist.all_reduce(skip_update, op=dist.ReduceOp.SUM) @@ -162,11 +202,9 @@ def train(): continue 
# some rank has no foreground, skip this batch optimizer.zero_grad() outputs = model( - input_images=inputs, - point_coords=point, - point_labels=point_label + input_images=inputs, point_coords=point, point_labels=point_label ) - if local_rank==0 and step % 50 == 0: + if local_rank == 0 and step % 50 == 0: plot_to_tensorboard(writer, step, inputs, labels, point, outputs) loss, loss_n = torch.tensor(0.0, device=device), torch.tensor( @@ -178,20 +216,31 @@ def train(): continue # skip background class loss_n += 1.0 gt = labels == prompt_class[idx] - loss += monai.losses.DiceCELoss(include_background=False, sigmoid=True, smooth_dr=1.0e-05, - smooth_nr=0, softmax=False, squared_pred=True, - to_onehot_y=False)(outputs[[idx]].float(), gt.float()) + loss += monai.losses.DiceCELoss( + include_background=False, + sigmoid=True, + smooth_dr=1.0e-05, + smooth_nr=0, + softmax=False, + squared_pred=True, + to_onehot_y=False, + )(outputs[[idx]].float(), gt.float()) loss /= max(loss_n, 1.0) print(loss) loss.backward() optimizer.step() step += 1 if local_rank == 0: - writer.add_scalar('loss', loss.item(), step) + writer.add_scalar("loss", loss.item(), step) if local_rank == 0 and epoch % 5 == 0: checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch{epoch}.pth") - torch.save({'model': model.state_dict(), 'epoch': epoch, 'step':step}, checkpoint_path) - print(f"Rank {local_rank}, Epoch {epoch}, Loss: {loss.item()}, Checkpoint saved: {checkpoint_path}") + torch.save( + {"model": model.state_dict(), "epoch": epoch, "step": step}, + checkpoint_path, + ) + print( + f"Rank {local_rank}, Epoch {epoch}, Loss: {loss.item()}, Checkpoint saved: {checkpoint_path}" + ) lr_scheduler.step() dist.destroy_process_group() @@ -199,4 +248,4 @@ def train(): if __name__ == "__main__": train() - # torchrun --nnodes=1 --nproc_per_node=8 train_cvpr.py \ No newline at end of file + # torchrun --nnodes=1 --nproc_per_node=8 train_cvpr.py diff --git a/vista3d/cvpr_workshop/update_ckpt.py b/vista3d/cvpr_workshop/update_ckpt.py index d798f66..3cc783b 100755 --- a/vista3d/cvpr_workshop/update_ckpt.py +++ b/vista3d/cvpr_workshop/update_ckpt.py @@ -1,16 +1,20 @@ -import torch import argparse +import torch + + def remove_module_prefix(input_pth, output_pth): # Load the checkpoint - checkpoint = torch.load(input_pth, map_location="cpu")['model'] - + checkpoint = torch.load(input_pth, map_location="cpu")["model"] + # Modify the state_dict to remove 'module.' prefix new_state_dict = {} for key, value in checkpoint.items(): if isinstance(value, dict) and "state_dict" in value: # If the checkpoint contains a 'state_dict' key (common in some saved models) - new_state_dict = {k.replace("module.", ""): v for k, v in value["state_dict"].items()} + new_state_dict = { + k.replace("module.", ""): v for k, v in value["state_dict"].items() + } value["state_dict"] = new_state_dict torch.save(value, output_pth) print(f"Updated weights saved to {output_pth}") @@ -19,15 +23,20 @@ def remove_module_prefix(input_pth, output_pth): new_state_dict[key.replace("module.", "")] = value else: new_state_dict[key] = value - + # Save the modified weights torch.save(new_state_dict, output_pth) print(f"Updated weights saved to {output_pth}") + if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Remove 'module.' prefix from PyTorch weights") + parser = argparse.ArgumentParser( + description="Remove 'module.' 
prefix from PyTorch weights" + ) parser.add_argument("--input", required=True, help="Path to input .pth file") - parser.add_argument("--output", required=True, help="Path to save the modified .pth file") + parser.add_argument( + "--output", required=True, help="Path to save the modified .pth file" + ) args = parser.parse_args() remove_module_prefix(args.input, args.output)
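`update_ckpt.py` strips the `module.` prefix that `DistributedDataParallel` adds to every parameter name in the checkpoints written by `train_cvpr.py`, e.g. `python update_ckpt.py --input checkpoints/model_epoch95.pth --output checkpoints/model_final.pth` (the epoch number and output name are placeholders). A short sanity-check sketch on the converted file, assuming it was produced this way:

```
import torch
from monai.networks.nets import vista3d132

state_dict = torch.load("checkpoints/model_final.pth", map_location="cpu")

# No key should keep the DDP "module." prefix after conversion.
assert not any(k.startswith("module.") for k in state_dict)

# The cleaned state_dict should load into the bare single-GPU model,
# which is how infer_cvpr.py consumes it (strict=True).
model = vista3d132(in_channels=1)
model.load_state_dict(state_dict, strict=True)
print(f"{len(state_dict)} parameter tensors loaded")
```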