diff --git a/tools/README.md b/tools/README.md
index cb50aca..06dad6f 100644
--- a/tools/README.md
+++ b/tools/README.md
@@ -1,5 +1,52 @@
# Semi-supervised VOS Inference
+This repository contains tools for semi-supervised video object segmentation (VOS) using SAM2. It includes a Gradio web interface for interactive segmentation and a command-line script for batch processing.
+
+## Gradio Web Interface
+
+To run the interactive Gradio web interface:
+
+1. Ensure SAM 2 and its dependencies are installed, along with `gradio` and `opencv-python`.
+2. Run the following command:
+
+```bash
+python tools/gradio_app.py
+```
+
+3. Open the provided URL in your web browser.
+4. Use the interface to:
+ - Upload a video
+ - Select points on the first frame to indicate the object to track
+ - Process the video to generate a masked output
+
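+By default the interface is served on localhost only. If you need to reach it from another machine, or want a temporary public link, the `demo.launch()` call at the bottom of `tools/gradio_app.py` can be adjusted. A minimal sketch using standard Gradio `launch()` options (these are not exposed as flags by the script):
+
+```python
+# Optional tweak to the end of tools/gradio_app.py:
+# listen on all interfaces on a fixed port; share=True requests a temporary public gradio.live URL.
+if __name__ == "__main__":
+    demo = create_interface()
+    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
+```
+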
+## Video Inference Script
+
+For batch processing or command-line usage, use the `video_inference.py` script:
+
+```bash
+python tools/video_inference.py \
+ --video_path /path/to/input/video.mp4 \
+ --points 100 150 200 250 \
+ --output_video /path/to/output/video.mp4 \
+ --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \
+ --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \
+ --num_pathway 3 \
+ --iou_thre 0.1 \
+ --uncertainty 2
+```
+
+Arguments:
+- `--video_path`: Path to the input video file
+- `--points`: Space-separated x,y coordinates of the initial clicks (format: x1 y1 x2 y2 ...); all points are treated as positive (foreground) prompts
+- `--output_video`: Path for the output video with mask overlay
+- `--sam2_cfg`: Path to SAM2 config file
+- `--sam2_checkpoint`: Path to SAM2 checkpoint file
+- `--num_pathway`: Number of segmentation pathways (default: 3)
+- `--iou_thre`: IoU threshold for filtering masks (default: 0.1)
+- `--uncertainty`: Uncertainty threshold for mask selection (default: 2)
+
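+The pipeline can also be driven from Python. A rough sketch (the paths, checkpoint, and point coordinates below are placeholders; it assumes you run from the repository root so that `tools/video_inference.py` is importable):
+
+```python
+import numpy as np
+
+from tools.video_inference import VOS
+
+# Build the image + video predictors (same config/checkpoint as the CLI defaults).
+vos = VOS(
+    sam2_cfg="configs/sam2.1/sam2.1_hiera_b+.yaml",
+    sam2_checkpoint="./checkpoints/sam2.1_hiera_base_plus.pt",
+    device="cuda",
+)
+
+# Positive clicks on the target object in the first frame, as (x, y) pixel coordinates.
+points = np.array([[100, 150], [200, 250]])
+
+# Returns one boolean mask per frame of the input video.
+masks = vos.process_video(
+    "/path/to/input/video.mp4",
+    points,
+    num_pathway=3,
+    iou_thre=0.1,
+    uncertainty=2,
+)
+print(f"Segmented {len(masks)} frames")
+```
+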
+## Dataset Evaluation
+
The `vos_inference.py` script can be used to generate predictions for semi-supervised video object segmentation (VOS) evaluation on datasets such as [DAVIS](https://davischallenge.org/index.html), [LVOS](https://lingyihongfd.github.io/lvos.github.io/), [MOSE](https://henghuiding.github.io/MOSE/) or the SA-V dataset.
After installing SAM 2 and its dependencies, it can be used as follows ([DAVIS 2017 dataset](https://davischallenge.org/davis2017/code.html) as an example). This script saves the prediction PNG files to the `--output_mask_dir`.
diff --git a/tools/gradio_app.py b/tools/gradio_app.py
new file mode 100644
index 0000000..2a8fe6a
--- /dev/null
+++ b/tools/gradio_app.py
@@ -0,0 +1,222 @@
+import os
+import cv2
+import numpy as np
+import torch
+import gradio as gr
+from PIL import Image
+
+from sam2.build_sam import build_sam2
+from sam2.build_sam import build_sam2_video_predictor
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+def get_first_frame(video_path):
+ """Extract and return the first frame of the video"""
+ if not video_path:
+ return None
+
+ cap = cv2.VideoCapture(video_path)
+ ret, frame = cap.read()
+ cap.release()
+
+ if not ret:
+ return None
+
+ # Convert BGR to RGB
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ return frame_rgb
+
+class VOS:
+ def __init__(self, sam2_cfg, sam2_checkpoint, device="cuda"):
+ self.sam2_model = build_sam2(sam2_cfg, sam2_checkpoint, device=device)
+ self.image_predictor = SAM2ImagePredictor(self.sam2_model)
+
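+        # Allow per-object masks to overlap by disabling SAM2's non-overlap post-processing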
+ hydra_overrides_extra = ["++model.non_overlap_masks=false"]
+ self.video_predictor = build_sam2_video_predictor(
+ config_file=sam2_cfg,
+ ckpt_path=sam2_checkpoint,
+ apply_postprocessing=True,
+ hydra_overrides_extra=hydra_overrides_extra,
+ device=device
+ )
+
+ def process_video(self, video_path, points, num_pathway=3, iou_thre=0.1, uncertainty=2):
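+        """Segment the object given by first-frame points and track it through the video; returns (masks, fps, (height, width))."""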
+ temp_dir = os.path.join(os.path.dirname(video_path), "temp_frames")
+ os.makedirs(temp_dir, exist_ok=True)
+
+ cap = cv2.VideoCapture(video_path)
+ ret, first_frame = cap.read()
+ if not ret:
+ raise ValueError("Could not read video")
+
+ first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
+ height, width = first_frame.shape[:2]
+ fps = cap.get(cv2.CAP_PROP_FPS)
+
+ self.image_predictor.set_image(first_frame_rgb)
+ input_points = np.array(points)
+ input_labels = np.ones(len(points))
+ masks, scores, logits = self.image_predictor.predict(
+ point_coords=input_points,
+ point_labels=input_labels,
+ multimask_output=True,
+ )
+
+ best_mask = masks[scores.argmax()]
+
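+        # Dump all frames as JPEGs; the video predictor loads its input from a frame directory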
+ frame_count = 0
+ current_frame = first_frame
+
+ frame_name = f"{frame_count:05d}.jpg"
+ frame_path = os.path.join(temp_dir, frame_name)
+ cv2.imwrite(frame_path, current_frame)
+ frame_count += 1
+
+ while True:
+ ret, current_frame = cap.read()
+ if not ret:
+ break
+
+ frame_name = f"{frame_count:05d}.jpg"
+ frame_path = os.path.join(temp_dir, frame_name)
+ cv2.imwrite(frame_path, current_frame)
+ frame_count += 1
+
+ cap.release()
+
+ inference_state = self.video_predictor.init_state(
+ video_path=temp_dir,
+ async_loading_frames=False
+ )
+
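+        # Tracking hyper-parameters read by the video predictor during propagation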
+ inference_state['num_pathway'] = num_pathway
+ inference_state['iou_thre'] = iou_thre
+ inference_state['uncertainty'] = uncertainty
+
+ self.video_predictor.add_new_mask(
+ inference_state=inference_state,
+ frame_idx=0,
+ obj_id=1,
+ mask=best_mask
+ )
+
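+        # Propagate the first-frame mask through the video, collecting one mask per frame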
+ all_masks = []
+ for frame_idx, obj_ids, mask_logits in self.video_predictor.propagate_in_video(inference_state):
+            mask = (mask_logits[0] > 0).cpu().numpy().squeeze()  # squeeze to an (H, W) boolean mask
+ all_masks.append(mask)
+
+ import shutil
+ shutil.rmtree(temp_dir)
+
+ return all_masks, fps, (height, width)
+
+def process_click(image, evt: gr.SelectData):
+ """Process click events on the image"""
+    # Accumulate clicks on the function object so they persist across select events
+    points = getattr(process_click, 'points', [])
+    points.append(evt.index)
+    process_click.points = points
+
+    # Draw the selected points on a copy of the frame
+    img = np.copy(image)
+    for point in points:
+        cv2.circle(img, tuple(point), 5, (0, 255, 0), -1)
+
+ return img, points
+
+def clear_points(image):
+ """Clear all points from the image"""
+ process_click.points = []
+ return image, []
+
+def process_video(video_path, points, sam2_cfg, sam2_checkpoint):
+ """Process video with the given points"""
+ if not points:
+ return None, "Please select at least one point on the first frame"
+
+ vos = VOS(sam2_cfg, sam2_checkpoint, device=DEVICE)
+ points = np.array(points)
+
+ try:
+ masks, fps, (height, width) = vos.process_video(video_path, points)
+
+ # Create output video with mask overlay
+ cap = cv2.VideoCapture(video_path)
+ output_path = "output_masked.mp4"
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+ for mask in masks:
+ ret, frame = cap.read()
+ if not ret:
+ break
+
+ mask_vis = np.zeros_like(frame)
+ mask_vis[:,:,1] = mask.astype(np.uint8) * 255
+ frame_with_mask = cv2.addWeighted(frame, 1, mask_vis, 0.5, 0)
+ out.write(frame_with_mask)
+
+ cap.release()
+ out.release()
+
+ return output_path, "Processing complete"
+ except Exception as e:
+ return None, f"Error processing video: {str(e)}"
+
+def create_interface():
+ with gr.Blocks() as demo:
+ gr.HTML("
Video Object Segmentation with SAM2
")
+ gr.Image("img/logo.png", show_label=False, width=50)
+ gr.Markdown("1. Upload a video\n2. Click points on the object you want to track\n3. Click Process to generate the masked video")
+
+ with gr.Row():
+ with gr.Column():
+ video_input = gr.Video(label="Input Video")
+ first_frame = gr.Image(label="Select Points on First Frame", interactive=True, type="numpy")
+ points_state = gr.State([])
+
+ with gr.Row():
+ clear_btn = gr.Button("Clear Points")
+ process_btn = gr.Button("Process Video", variant="primary")
+
+ with gr.Column():
+ video_output = gr.Video(label="Output Video")
+ status = gr.Textbox(label="Status")
+
+ # When video is uploaded, extract and show first frame
+ video_input.change(
+ fn=get_first_frame,
+ inputs=[video_input],
+ outputs=[first_frame]
+ )
+
+ # Handle point selection
+ first_frame.select(
+ process_click,
+ inputs=[first_frame],
+ outputs=[first_frame, points_state]
+ )
+
+ # Clear points
+ clear_btn.click(
+ clear_points,
+ inputs=[first_frame],
+ outputs=[first_frame, points_state]
+ )
+
+ # Process video
+ process_btn.click(
+ process_video,
+ inputs=[
+ video_input,
+ points_state,
+ gr.Textbox(value="configs/sam2.1/sam2.1_hiera_b+.yaml", visible=False),
+ gr.Textbox(value="./checkpoints/sam2.1_hiera_base_plus.pt", visible=False)
+ ],
+ outputs=[video_output, status]
+ )
+
+ return demo
+
+if __name__ == "__main__":
+ demo = create_interface()
+ demo.launch()
\ No newline at end of file
diff --git a/tools/video_inference.py b/tools/video_inference.py
new file mode 100644
index 0000000..5370ca3
--- /dev/null
+++ b/tools/video_inference.py
@@ -0,0 +1,179 @@
+import argparse
+import os
+import cv2
+import numpy as np
+import torch
+
+from sam2.build_sam import build_sam2
+from sam2.build_sam import build_sam2_video_predictor
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+class VOS:
+ def __init__(self, sam2_cfg, sam2_checkpoint, device="cuda"):
+ """Initialize both image and video predictors"""
+ # Initialize image predictor for first frame
+ self.sam2_model = build_sam2(sam2_cfg, sam2_checkpoint, device=device)
+ self.image_predictor = SAM2ImagePredictor(self.sam2_model)
+
+ # Initialize video predictor for tracking
+ hydra_overrides_extra = ["++model.non_overlap_masks=false"]
+ self.video_predictor = build_sam2_video_predictor(
+ config_file=sam2_cfg,
+ ckpt_path=sam2_checkpoint,
+ apply_postprocessing=True,
+ hydra_overrides_extra=hydra_overrides_extra,
+ device=device
+ )
+
+ def process_video(self, video_path, points, num_pathway=3, iou_thre=0.1, uncertainty=2):
+ """
+ Process a video with given initial points
+
+ Args:
+ video_path: Path to the video file
+ points: List of [x,y] coordinates to track
+ num_pathway: Number of segmentation pathways
+ iou_thre: IoU threshold for filtering masks
+ uncertainty: Uncertainty threshold for mask selection
+ """
+ # Create temporary directory for frames
+ temp_dir = os.path.join(os.path.dirname(video_path), "temp_frames")
+ os.makedirs(temp_dir, exist_ok=True)
+
+ # Extract first frame and get mask using points
+ cap = cv2.VideoCapture(video_path)
+ ret, first_frame = cap.read()
+ if not ret:
+ raise ValueError("Could not read video")
+
+ # Convert BGR to RGB
+ first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
+ height, width = first_frame.shape[:2]
+
+ # Get initial mask using SAM2 image predictor
+ self.image_predictor.set_image(first_frame_rgb)
+ input_points = np.array(points)
+ input_labels = np.ones(len(points)) # Assuming all points are positive
+ masks, scores, logits = self.image_predictor.predict(
+ point_coords=input_points,
+ point_labels=input_labels,
+ multimask_output=True,
+ )
+
+ # Get best mask
+ best_mask = masks[scores.argmax()]
+
+ # Save frames as images (needed for video predictor)
+ frame_count = 0
+ current_frame = first_frame # Start with the first frame
+
+ # Save the first frame
+ frame_name = f"{frame_count:05d}.jpg"
+ frame_path = os.path.join(temp_dir, frame_name)
+ cv2.imwrite(frame_path, current_frame)
+ frame_count += 1
+
+ # Save the rest of the frames
+ while True:
+ ret, current_frame = cap.read()
+ if not ret:
+ break
+
+ frame_name = f"{frame_count:05d}.jpg"
+ frame_path = os.path.join(temp_dir, frame_name)
+ cv2.imwrite(frame_path, current_frame)
+ frame_count += 1
+
+ cap.release()
+
+ # Initialize video predictor state
+ inference_state = self.video_predictor.init_state(
+ video_path=temp_dir,
+ async_loading_frames=False
+ )
+
+ # Set parameters
+ inference_state['num_pathway'] = num_pathway
+ inference_state['iou_thre'] = iou_thre
+ inference_state['uncertainty'] = uncertainty
+
+ # Add initial mask
+ self.video_predictor.add_new_mask(
+ inference_state=inference_state,
+ frame_idx=0,
+ obj_id=1, # Single object tracking
+ mask=best_mask
+ )
+
+ # Propagate through video
+ all_masks = []
+ for frame_idx, obj_ids, mask_logits in self.video_predictor.propagate_in_video(inference_state):
+            mask = (mask_logits[0] > 0).cpu().numpy().squeeze()  # first (only) object, squeezed to an (H, W) boolean mask
+ all_masks.append(mask)
+
+ import shutil
+ shutil.rmtree(temp_dir)
+
+ return all_masks
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--video_path", type=str, required=True, help="Path to input video")
+ parser.add_argument("--points", type=float, nargs='+', required=True,
+ help="List of x,y coordinates. Format: x1 y1 x2 y2 ...")
+ parser.add_argument("--output_video", type=str, required=True, help="Path to output video")
+ parser.add_argument("--sam2_cfg", type=str,
+ default="configs/sam2.1/sam2.1_hiera_b+.yaml",
+ help="SAM2 config path")
+ parser.add_argument("--sam2_checkpoint", type=str,
+ default="./checkpoints/sam2.1_hiera_base_plus.pt",
+ help="SAM2 checkpoint path")
+ parser.add_argument("--num_pathway", type=int, default=3)
+ parser.add_argument("--iou_thre", type=float, default=0.1)
+ parser.add_argument("--uncertainty", type=float, default=2)
+
+ args = parser.parse_args()
+
+ points = np.array(args.points).reshape(-1, 2)
+
+ vos = VOS(args.sam2_cfg, args.sam2_checkpoint, device=DEVICE)
+
+ # Process video
+ masks = vos.process_video(
+ args.video_path,
+ points,
+ num_pathway=args.num_pathway,
+ iou_thre=args.iou_thre,
+ uncertainty=args.uncertainty
+ )
+
+ # Create visualization video
+    cap = cv2.VideoCapture(args.video_path)
+    ret, frame = cap.read()
+    height, width = frame.shape[:2]
+    fps = cap.get(cv2.CAP_PROP_FPS)  # Get original video FPS
+    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # Rewind so the first mask is overlaid on the first frame
+
+ # Initialize video writer with original FPS
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ out = cv2.VideoWriter(args.output_video, fourcc, fps, (width, height))
+
+ # Process each frame
+ for i, mask in enumerate(masks):
+ ret, frame = cap.read()
+ if not ret:
+ break
+
+ # Apply mask overlay
+ mask_vis = np.zeros_like(frame)
+ mask_vis[:,:,1] = mask.astype(np.uint8) * 255 # Green channel
+ frame_with_mask = cv2.addWeighted(frame, 1, mask_vis, 0.5, 0)
+
+ out.write(frame_with_mask)
+
+ cap.release()
+ out.release()
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file