diff --git a/vista3d/cvpr_workshop/Dockerfile b/vista3d/cvpr_workshop/Dockerfile index 186d69f..418ef7d 100755 --- a/vista3d/cvpr_workshop/Dockerfile +++ b/vista3d/cvpr_workshop/Dockerfile @@ -21,4 +21,4 @@ COPY predict.sh /workspace/predict.sh RUN chmod +x /workspace/predict.sh # Set default command -CMD ["/bin/bash"] \ No newline at end of file +CMD ["/bin/bash"] diff --git a/vista3d/cvpr_workshop/README.md b/vista3d/cvpr_workshop/README.md index 6769fb9..a38091e 100644 --- a/vista3d/cvpr_workshop/README.md +++ b/vista3d/cvpr_workshop/README.md @@ -12,10 +12,9 @@ limitations under the License. --> # Overview -This repository is written for the "CVPR 2025: Foundation Models for Interactive 3D Biomedical Image Segmentation"([link](https://www.codabench.org/competitions/5263/)) challenge. It +This repository is written for the "CVPR 2025: Foundation Models for Interactive 3D Biomedical Image Segmentation" ([link](https://www.codabench.org/competitions/5263/)) challenge. It is based on MONAI 1.4: many functions from the main VISTA3D repository have been moved into MONAI 1.4, and this simplified folder uses those components directly from MONAI. - It is simplified to train interactive segmentation models across different modalities. The sophisticated transforms and recipes used for VISTA3D are removed. The finetuned VISTA3D checkpoint on the challenge subsets is available [here](https://drive.google.com/file/d/1r2KvHP_30nHR3LU7NJEdscVnlZ2hTtcd/view?usp=sharing) # Setup @@ -24,10 +23,12 @@ pip install -r requirements.txt ``` # Training -Download the challenge subsets finetuned [checkpoint](https://drive.google.com/file/d/1r2KvHP_30nHR3LU7NJEdscVnlZ2hTtcd/view?usp=sharing) or VISTA3D original [checkpoint]((https://drive.google.com/file/d/1DRYA2-AI-UJ23W1VbjqHsnHENGi0ShUl/view?usp=sharing)). Generate a json list that contains your traning data and update the json file path in the script. +Download the [checkpoint](https://drive.google.com/file/d/1hQ8imaf4nNSg_43dYbPSJT0dr7JgAKWX/view?usp=sharing) finetuned on the challenge subsets or the original VISTA3D [checkpoint](https://drive.google.com/file/d/1DRYA2-AI-UJ23W1VbjqHsnHENGi0ShUl/view?usp=sharing). Generate a JSON list of your training data (an example is sketched below) and update the JSON file path in the script. ``` torchrun --nnodes=1 --nproc_per_node=8 train_cvpr.py ``` +The checkpoint saved by `train_cvpr.py` can be post-processed with `update_ckpt.py` to strip the `module.` prefix that multi-GPU (DDP) training adds to parameter names. + # Inference You can directly download the [docker file](https://drive.google.com/file/d/1r2KvHP_30nHR3LU7NJEdscVnlZ2hTtcd/view?usp=sharing) for the challenge baseline. @@ -36,6 +37,4 @@ We provide a Dockerfile to satisfy the challenge format. For more details, refer docker build -t vista3d:latest . docker save -o vista3d.tar.gz vista3d:latest ``` - - - +You can also run `predict.sh` directly. Download the finetuned checkpoint and update `--model=/your_downloaded_checkpoint` accordingly. Set `save_data = True` in `infer_cvpr.py` to save predictions as NIfTI files for visualization.
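The `subset.json` mentioned above is read by `NPZDataset` in `train_cvpr.py` with a plain `json.load`, and each entry is joined onto `NPZDataset.base_path`, so the list is simply a flat JSON array of `.npz` file names. A minimal sketch for generating it (the directory layout is an assumption; adjust `base_path` to wherever the challenge subset was extracted):
```
import glob
import json
import os

# Same default directory as NPZDataset.base_path in train_cvpr.py.
base_path = "/workspace/VISTA/CVPR-MedSegFMCompetition/trainsubset"

# NPZDataset expects a flat JSON list of file names resolved against base_path.
files = sorted(os.path.basename(p) for p in glob.glob(os.path.join(base_path, "*.npz")))
with open("subset.json", "w") as f:
    json.dump(files, f, indent=2)
```
With `subset.json` in place, `train_cvpr.py` picks it up through the `json_file` variable at the top of `train()`.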
diff --git a/vista3d/cvpr_workshop/infer_cvpr.py b/vista3d/cvpr_workshop/infer_cvpr.py index 7694965..dddf7f2 100755 --- a/vista3d/cvpr_workshop/infer_cvpr.py +++ b/vista3d/cvpr_workshop/infer_cvpr.py @@ -1,147 +1,159 @@ +import argparse +import glob + import monai import monai.transforms -import torch -import argparse -import numpy as np import nibabel as nib -import glob -from monai.networks.nets.vista3d import vista3d132 -from monai.utils import optional_import +import numpy as np +import torch from monai.apps.vista3d.inferer import point_based_window_inferer from monai.inferers import SlidingWindowInfererAdapt +from monai.networks.nets.vista3d import vista3d132 +from monai.utils import optional_import tqdm, _ = optional_import("tqdm", name="tqdm") -import numpy as np -import pdb import os + + def convert_clicks(alldata): # indexes = list(alldata.keys()) # data = [alldata[i] for i in indexes] data = alldata B = len(data) # Number of objects - indexes = np.arange(1, B+1).tolist() + indexes = np.arange(1, B + 1).tolist() # Determine the maximum number of points across all objects - max_N = max(len(obj['fg']) + len(obj['bg']) for obj in data) - + max_N = max(len(obj["fg"]) + len(obj["bg"]) for obj in data) + # Initialize padded arrays point_coords = np.zeros((B, max_N, 3), dtype=int) point_labels = np.full((B, max_N), -1, dtype=int) - + for i, obj in enumerate(data): points = [] labels = [] - + # Add foreground points - for fg_point in obj['fg']: + for fg_point in obj["fg"]: points.append(fg_point) labels.append(1) - + # Add background points - for bg_point in obj['bg']: + for bg_point in obj["bg"]: points.append(bg_point) labels.append(0) - + # Fill in the arrays - point_coords[i, :len(points)] = points - point_labels[i, :len(labels)] = labels - + point_coords[i, : len(points)] = points + point_labels[i, : len(labels)] = labels + return point_coords, point_labels, indexes -if __name__ == '__main__': +if __name__ == "__main__": # set to true to save nifti files for visualization save_data = False - point_inferer = True # use point based inferen - roi_size = [128,128,128] + point_inferer = True # use point based inferen + roi_size = [128, 128, 128] parser = argparse.ArgumentParser() - parser.add_argument("--test_img_path", type=str, default='./tests') - parser.add_argument("--save_path", type=str, default='./outputs/') - parser.add_argument("--model", type=str, default='checkpoints/model_final.pth') + parser.add_argument("--test_img_path", type=str, default="./tests") + parser.add_argument("--save_path", type=str, default="./outputs/") + parser.add_argument("--model", type=str, default="checkpoints/model_final.pth") args = parser.parse_args() - os.makedirs(args.save_path,exist_ok=True) + os.makedirs(args.save_path, exist_ok=True) # load model checkpoint_path = args.model model = vista3d132(in_channels=1) - pretrained_ckpt = torch.load(checkpoint_path, map_location='cuda') + pretrained_ckpt = torch.load(checkpoint_path, map_location="cuda") model.load_state_dict(pretrained_ckpt, strict=True) - + # load data test_cases = glob.glob(os.path.join(args.test_img_path, "*.npz")) for img_path in test_cases: case_name = os.path.basename(img_path) print(case_name) img = np.load(img_path, allow_pickle=True) - img_array = img['imgs'] - spacing = img['spacing'] + img_array = img["imgs"] + spacing = img["spacing"] original_shape = img_array.shape affine = np.diag(spacing.tolist() + [1]) # 4x4 affine matrix if save_data: # Create a NIfTI image nifti_img = nib.Nifti1Image(img_array, affine) # Save the NIfTI 
file - nib.save(nifti_img, img_path.replace('.npz','.nii.gz')) - nifti_img = nib.Nifti1Image(img['gts'], affine) + nib.save(nifti_img, img_path.replace(".npz", ".nii.gz")) + nifti_img = nib.Nifti1Image(img["gts"], affine) # Save the NIfTI file - nib.save(nifti_img, img_path.replace('.npz','gts.nii.gz')) - clicks = img.get('clicks', [{'fg':[[418, 138, 136]], 'bg':[]}]) + nib.save(nifti_img, img_path.replace(".npz", "gts.nii.gz")) + clicks = img.get("clicks", [{"fg": [[418, 138, 136]], "bg": []}]) point_coords, point_labels, indexes = convert_clicks(clicks) # preprocess img_array = torch.from_numpy(img_array) - img_array = img_array.unsqueeze(0) - img_array = monai.transforms.ScaleIntensityRangePercentiles(lower=1, upper=99, b_min=0, b_max=1, clip=True)(img_array) - img_array = img_array.unsqueeze(0) # add channel dim - device = 'cuda' + img_array = img_array.unsqueeze(0) + img_array = monai.transforms.ScaleIntensityRangePercentiles( + lower=1, upper=99, b_min=0, b_max=1, clip=True + )(img_array) + img_array = img_array.unsqueeze(0) # add channel dim + device = "cuda" # slidingwindow with torch.no_grad(): if not point_inferer: - model.NINF_VALUE = 0 # set to 0 in case sliding window is used. + model.NINF_VALUE = 0 # set to 0 in case sliding window is used. # directly using slidingwindow inferer is not optimal. - val_outputs = SlidingWindowInfererAdapt( - roi_size=roi_size, sw_batch_size=1, with_coord=True, padding_mode="replicate" - )( - inputs=img_array.to(device), - transpose=True, - network=model.to(device), - point_coords=torch.from_numpy(point_coords).to(device), - point_labels=torch.from_numpy(point_labels).to(device) - )[0] > 0 + val_outputs = ( + SlidingWindowInfererAdapt( + roi_size=roi_size, + sw_batch_size=1, + with_coord=True, + padding_mode="replicate", + )( + inputs=img_array.to(device), + transpose=True, + network=model.to(device), + point_coords=torch.from_numpy(point_coords).to(device), + point_labels=torch.from_numpy(point_labels).to(device), + )[ + 0 + ] + > 0 + ) final_outputs = torch.zeros_like(val_outputs[0], dtype=torch.float32) for i, v in enumerate(val_outputs): final_outputs += indexes[i] * v else: # point based - final_outputs = torch.zeros_like(img_array[0,0], dtype=torch.float32) + final_outputs = torch.zeros_like(img_array[0, 0], dtype=torch.float32) for i, v in enumerate(indexes): - val_outputs = point_based_window_inferer( - inputs=img_array.to(device), - roi_size=roi_size, - transpose=True, - with_coord=True, - predictor=model.to(device), - mode="gaussian", - sw_device=device, - device=device, - center_only=True, # only crop the center - point_coords=torch.from_numpy(point_coords[[i]]).to(device), - point_labels=torch.from_numpy(point_labels[[i]]).to(device) - )[0] > 0 + val_outputs = ( + point_based_window_inferer( + inputs=img_array.to(device), + roi_size=roi_size, + transpose=True, + with_coord=True, + predictor=model.to(device), + mode="gaussian", + sw_device=device, + device=device, + center_only=True, # only crop the center + point_coords=torch.from_numpy(point_coords[[i]]).to(device), + point_labels=torch.from_numpy(point_labels[[i]]).to(device), + )[0] + > 0 + ) final_outputs[val_outputs[0]] = v final_outputs = torch.nan_to_num(final_outputs) - # save data + # save data if save_data: # Create a NIfTI image - nifti_img = nib.Nifti1Image(final_outputs.to(torch.float32).data.cpu().numpy(), affine) + nifti_img = nib.Nifti1Image( + final_outputs.to(torch.float32).data.cpu().numpy(), affine + ) # Save the NIfTI file - nib.save(nifti_img, 
os.path.join(args.save_path, case_name.replace('.npz','.nii.gz'))) - np.savez_compressed(os.path.join(args.save_path, case_name), segs=final_outputs.to(torch.float32).data.cpu().numpy()) - - - - - - - - - - \ No newline at end of file + nib.save( + nifti_img, + os.path.join(args.save_path, case_name.replace(".npz", ".nii.gz")), + ) + np.savez_compressed( + os.path.join(args.save_path, case_name), + segs=final_outputs.to(torch.float32).data.cpu().numpy(), + ) diff --git a/vista3d/cvpr_workshop/requirements.txt b/vista3d/cvpr_workshop/requirements.txt index c57fbde..308df32 100755 --- a/vista3d/cvpr_workshop/requirements.txt +++ b/vista3d/cvpr_workshop/requirements.txt @@ -10,4 +10,4 @@ numpy scipy cupy-cuda12x cucim -tqdm \ No newline at end of file +tqdm diff --git a/vista3d/cvpr_workshop/train_cvpr.py b/vista3d/cvpr_workshop/train_cvpr.py index 25a33bc..c905e98 100755 --- a/vista3d/cvpr_workshop/train_cvpr.py +++ b/vista3d/cvpr_workshop/train_cvpr.py @@ -1,35 +1,34 @@ -import os import json +import os +import warnings + +import monai import monai.transforms +import numpy as np import torch -import torch.nn as nn -import torch.optim as optim import torch.distributed as dist -from torch.utils.data import Dataset -from torch.nn.parallel import DistributedDataParallel as DDP -import numpy as np -import monai -from tqdm import tqdm -import pdb -from monai.networks.nets import vista3d132 +import torch.optim as optim from monai.apps.vista3d.sampler import sample_prompt_pairs +from monai.data import DataLoader +from monai.networks.nets import vista3d132 +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import Dataset from torch.utils.tensorboard import SummaryWriter -from monai.data import DataLoader, DistributedSampler -import warnings -import nibabel as nib +from tqdm import tqdm + warnings.simplefilter("ignore") # Custom dataset for .npz files import matplotlib.pyplot as plt -import torchvision.utils as vutils -NUM_PATCHES_PER_IMAGE=4 +NUM_PATCHES_PER_IMAGE = 4 + def plot_to_tensorboard(writer, epoch, inputs, labels, points, outputs): """ Plots B figures, where each figure shows the slice where the point is located and overlays the point on this slice. 
- + Args: writer: TensorBoard writer epoch: Current epoch number @@ -44,72 +43,107 @@ def plot_to_tensorboard(writer, epoch, inputs, labels, points, outputs): for b in range(B): fig, axes = plt.subplots(1, 3, figsize=(12, 4)) - + # Select the first click point in (z, y, x) format x, y, z = points[b, 0].cpu().numpy().astype(int) - + # Extract the corresponding slice input_slice = inputs_np[:, :, z] # Get slice at depth z label_slice = labels_np[:, :, z] output_slice = outputs[b, 0].cpu().detach().numpy()[:, :, z] > 0 - + # Plot input with point overlay - axes[0].imshow(input_slice, cmap='gray') - axes[0].scatter(y, x, c='red', marker='x', s=50) + axes[0].imshow(input_slice, cmap="gray") + axes[0].scatter(y, x, c="red", marker="x", s=50) axes[0].set_title(f"Input (Slice {z})") - + # Plot label - axes[1].imshow(label_slice, cmap='gray') - axes[0].scatter(y, x, c='red', marker='x', s=50) + axes[1].imshow(label_slice, cmap="gray") + axes[1].scatter(y, x, c="red", marker="x", s=50) axes[1].set_title(f"Ground Truth (Slice {z})") - + # Plot output - axes[2].imshow(output_slice, cmap='gray') - axes[0].scatter(y, x, c='red', marker='x', s=50) + axes[2].imshow(output_slice, cmap="gray") + axes[2].scatter(y, x, c="red", marker="x", s=50) axes[2].set_title(f"Model Output (Slice {z})") - + plt.tight_layout() - + # Log figure to TensorBoard writer.add_figure(f"Object_{b}_Segmentation", fig, epoch) plt.close(fig) + class NPZDataset(Dataset): def __init__(self, json_file): - with open(json_file, 'r') as f: + with open(json_file, "r") as f: self.file_paths = json.load(f) - self.base_path = '/workspace/VISTA/CVPR-MedSegFMCompetition/trainsubset' + self.base_path = "/workspace/VISTA/CVPR-MedSegFMCompetition/trainsubset" + def __len__(self): return len(self.file_paths) def __getitem__(self, idx): img = np.load(os.path.join(self.base_path, self.file_paths[idx])) - img_array = torch.from_numpy(img['imgs']).unsqueeze(0).to(torch.float32) - label = torch.from_numpy(img['gts']).unsqueeze(0).to(torch.int32) - data = {"image": img_array, "label": label, 'filename': self.file_paths[idx]} - affine = np.diag(img['spacing'].tolist() + [1]) # 4x4 affine matrix - transforms = monai.transforms.Compose([ - monai.transforms.ScaleIntensityRangePercentilesd(keys="image", lower=1, upper=99, b_min=0, b_max=1, clip=True), - monai.transforms.SpatialPadd(mode=["constant", "constant"], keys=["image", "label"], spatial_size=[128, 128, 128]), - monai.transforms.RandCropByLabelClassesd(spatial_size=[128, 128, 128], keys=["image", "label"], label_key="label",num_classes=label.max() + 1, num_samples=NUM_PATCHES_PER_IMAGE), - monai.transforms.RandScaleIntensityd(factors=0.2, prob=0.2, keys="image"), - monai.transforms.RandShiftIntensityd(offsets=0.2, prob=0.2, keys="image"), - monai.transforms.RandGaussianNoised(mean=0., std=0.2, prob=0.2, keys="image"), - monai.transforms.RandFlipd(spatial_axis=0, prob=0.2, keys=["image", "label"]), - monai.transforms.RandFlipd(spatial_axis=1, prob=0.2, keys=["image", "label"]), - monai.transforms.RandFlipd(spatial_axis=2, prob=0.2, keys=["image", "label"]), - monai.transforms.RandRotate90d(max_k=3, prob=0.2, keys=["image", "label"]) - ]) + img_array = torch.from_numpy(img["imgs"]).unsqueeze(0).to(torch.float32) + label = torch.from_numpy(img["gts"]).unsqueeze(0).to(torch.int32) + data = {"image": img_array, "label": label, "filename": self.file_paths[idx]} + affine = np.diag(img["spacing"].tolist() + [1]) # 4x4 affine matrix + transforms = monai.transforms.Compose( + [ +
monai.transforms.ScaleIntensityRangePercentilesd( + keys="image", lower=1, upper=99, b_min=0, b_max=1, clip=True + ), + monai.transforms.SpatialPadd( + mode=["constant", "constant"], + keys=["image", "label"], + spatial_size=[128, 128, 128], + ), + monai.transforms.RandCropByLabelClassesd( + spatial_size=[128, 128, 128], + keys=["image", "label"], + label_key="label", + num_classes=label.max() + 1, + num_samples=NUM_PATCHES_PER_IMAGE, + ), + monai.transforms.RandScaleIntensityd( + factors=0.2, prob=0.2, keys="image" + ), + monai.transforms.RandShiftIntensityd( + offsets=0.2, prob=0.2, keys="image" + ), + monai.transforms.RandGaussianNoised( + mean=0.0, std=0.2, prob=0.2, keys="image" + ), + monai.transforms.RandFlipd( + spatial_axis=0, prob=0.2, keys=["image", "label"] + ), + monai.transforms.RandFlipd( + spatial_axis=1, prob=0.2, keys=["image", "label"] + ), + monai.transforms.RandFlipd( + spatial_axis=2, prob=0.2, keys=["image", "label"] + ), + monai.transforms.RandRotate90d( + max_k=3, prob=0.2, keys=["image", "label"] + ), + ] + ) data = transforms(data) return data + + # Training function def train(): json_file = "subset.json" # Update with your JSON file epoch_number = 100 start_epoch = 0 lr = 2e-5 checkpoint_dir = "checkpoints" - start_checkpoint = '/workspace/CPRR25_vista3D_model_final_10percent_data.pth' + start_checkpoint = "/workspace/CPRR25_vista3D_model_final_10percent_data.pth" + os.makedirs(checkpoint_dir, exist_ok=True) dist.init_process_group(backend="nccl") world_size = int(os.environ["WORLD_SIZE"]) @@ -117,15 +151,24 @@ def train(): torch.cuda.set_device(local_rank) device = torch.device(f"cuda:{local_rank}") dataset = NPZDataset(json_file) - sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=local_rank) + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, num_replicas=world_size, rank=local_rank + ) dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=32) model = vista3d132(in_channels=1).to(device) pretrained_ckpt = torch.load(start_checkpoint, map_location=device) # pretrained_ckpt = torch.load(os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth")) model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) - model.load_state_dict(pretrained_ckpt['model'], strict=True) + model.load_state_dict(pretrained_ckpt["model"], strict=True) optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1.0e-05) - lr_scheduler = monai.optimizers.WarmupCosineSchedule(optimizer=optimizer, t_total= epoch_number+1, warmup_multiplier=0.1, warmup_steps=0) + lr_scheduler = monai.optimizers.WarmupCosineSchedule( + optimizer=optimizer, + t_total=epoch_number + 1, + warmup_multiplier=0.1, + warmup_steps=0, + ) if local_rank == 0: writer = SummaryWriter(log_dir=os.path.join(checkpoint_dir, "Events")) @@ -146,10 +189,12 @@ def train(): max_prompt=10, drop_label_prob=1, drop_point_prob=0, - ) + ) skip_update = torch.zeros(1, device=device) if point is None: - print(f"Iteration skipped due to None prompts at {batch['filename']}") + print( + f"Iteration skipped due to None prompts at {batch['filename']}" + ) skip_update = torch.ones(1, device=device) if world_size > 1:
dist.all_reduce(skip_update, op=dist.ReduceOp.SUM) @@ -157,11 +202,9 @@ def train(): continue # some rank has no foreground, skip this batch optimizer.zero_grad() outputs = model( - input_images=inputs, - point_coords=point, - point_labels=point_label + input_images=inputs, point_coords=point, point_labels=point_label ) - if local_rank==0 and step % 50 == 0: + if local_rank == 0 and step % 50 == 0: plot_to_tensorboard(writer, step, inputs, labels, point, outputs) loss, loss_n = torch.tensor(0.0, device=device), torch.tensor( @@ -173,20 +216,31 @@ def train(): continue # skip background class loss_n += 1.0 gt = labels == prompt_class[idx] - loss += monai.losses.DiceCELoss(include_background=False, sigmoid=True, smooth_dr=1.0e-05, - smooth_nr=0, softmax=False, squared_pred=True, - to_onehot_y=False)(outputs[[idx]].float(), gt.float()) + loss += monai.losses.DiceCELoss( + include_background=False, + sigmoid=True, + smooth_dr=1.0e-05, + smooth_nr=0, + softmax=False, + squared_pred=True, + to_onehot_y=False, + )(outputs[[idx]].float(), gt.float()) loss /= max(loss_n, 1.0) print(loss) loss.backward() optimizer.step() step += 1 if local_rank == 0: - writer.add_scalar('loss', loss.item(), step) + writer.add_scalar("loss", loss.item(), step) if local_rank == 0 and epoch % 5 == 0: checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch{epoch}.pth") - torch.save({'model': model.state_dict(), 'epoch': epoch, 'step':step}, checkpoint_path) - print(f"Rank {local_rank}, Epoch {epoch}, Loss: {loss.item()}, Checkpoint saved: {checkpoint_path}") + torch.save( + {"model": model.state_dict(), "epoch": epoch, "step": step}, + checkpoint_path, + ) + print( + f"Rank {local_rank}, Epoch {epoch}, Loss: {loss.item()}, Checkpoint saved: {checkpoint_path}" + ) lr_scheduler.step() dist.destroy_process_group() @@ -194,4 +248,4 @@ def train(): if __name__ == "__main__": train() - # torchrun --nnodes=1 --nproc_per_node=8 train_cvpr.py \ No newline at end of file + # torchrun --nnodes=1 --nproc_per_node=8 train_cvpr.py diff --git a/vista3d/cvpr_workshop/update_ckpt.py b/vista3d/cvpr_workshop/update_ckpt.py index d798f66..3cc783b 100755 --- a/vista3d/cvpr_workshop/update_ckpt.py +++ b/vista3d/cvpr_workshop/update_ckpt.py @@ -1,16 +1,20 @@ -import torch import argparse +import torch + + def remove_module_prefix(input_pth, output_pth): # Load the checkpoint - checkpoint = torch.load(input_pth, map_location="cpu")['model'] - + checkpoint = torch.load(input_pth, map_location="cpu")["model"] + # Modify the state_dict to remove 'module.' prefix new_state_dict = {} for key, value in checkpoint.items(): if isinstance(value, dict) and "state_dict" in value: # If the checkpoint contains a 'state_dict' key (common in some saved models) - new_state_dict = {k.replace("module.", ""): v for k, v in value["state_dict"].items()} + new_state_dict = { + k.replace("module.", ""): v for k, v in value["state_dict"].items() + } value["state_dict"] = new_state_dict torch.save(value, output_pth) print(f"Updated weights saved to {output_pth}") @@ -19,15 +23,20 @@ def remove_module_prefix(input_pth, output_pth): new_state_dict[key.replace("module.", "")] = value else: new_state_dict[key] = value - + # Save the modified weights torch.save(new_state_dict, output_pth) print(f"Updated weights saved to {output_pth}") + if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Remove 'module.' prefix from PyTorch weights") + parser = argparse.ArgumentParser( + description="Remove 'module.' 
prefix from PyTorch weights" + ) parser.add_argument("--input", required=True, help="Path to input .pth file") - parser.add_argument("--output", required=True, help="Path to save the modified .pth file") + parser.add_argument( + "--output", required=True, help="Path to save the modified .pth file" + ) args = parser.parse_args() remove_module_prefix(args.input, args.output)
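Taken together, a typical post-training flow is: take the latest checkpoint written by `train_cvpr.py`, strip the DDP `module.` prefix with `update_ckpt.py`, and point `infer_cvpr.py` (or `predict.sh`) at the result. A minimal sketch, assuming an epoch-95 checkpoint exists under `checkpoints/` (the file names are illustrative, not shipped with the repository):
```
from update_ckpt import remove_module_prefix

# train_cvpr.py saves {"model": ..., "epoch": ..., "step": ...} every 5 epochs;
# remove_module_prefix() loads the "model" entry and drops the "module." prefix.
remove_module_prefix("checkpoints/model_epoch95.pth", "checkpoints/model_final.pth")
```
The cleaned checkpoint matches the default `--model checkpoints/model_final.pth` of `infer_cvpr.py`, so inference can then be run with `python infer_cvpr.py --test_img_path ./tests --save_path ./outputs/ --model checkpoints/model_final.pth`.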