Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,17 @@ def load_video_qwen_vl_utils(
if self.config.video_sampling_strategy == "frame_num":
n_frames = self.config.frame_num
video_dict["nframes"] = n_frames
video_inputs, sample_fps = fetch_video(video_dict, return_video_sample_fps=True, return_video_metadata=True)
video_inputs, sample_fps = fetch_video(
video_dict, image_patch_size=16, return_video_sample_fps=True, return_video_metadata=True
)
frames, video_metadata = video_inputs
frames = frames.numpy()
return frames, video_metadata, sample_fps
elif self.config.video_sampling_strategy == "fps":
video_dict["fps"] = fps
video_inputs, sample_fps = fetch_video(video_dict, return_video_sample_fps=True, return_video_metadata=True)
video_inputs, sample_fps = fetch_video(
video_dict, image_patch_size=16, return_video_sample_fps=True, return_video_metadata=True
)
frames, video_metadata = video_inputs
frames = frames.numpy()
return frames, video_metadata, sample_fps
Expand Down
8 changes: 6 additions & 2 deletions src/lmms_engine/datasets/naive/qwen3_vl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,17 @@ def load_video_qwen_vl_utils(
if self.config.video_sampling_strategy == "frame_num":
n_frames = self.config.frame_num
video_dict["nframes"] = n_frames
video_inputs, sample_fps = fetch_video(video_dict, return_video_sample_fps=True, return_video_metadata=True)
video_inputs, sample_fps = fetch_video(
video_dict, image_patch_size=16, return_video_sample_fps=True, return_video_metadata=True
)
frames, video_metadata = video_inputs
frames = frames.numpy()
return frames, video_metadata, sample_fps
elif self.config.video_sampling_strategy == "fps":
video_dict["fps"] = fps
video_inputs, sample_fps = fetch_video(video_dict, return_video_sample_fps=True, return_video_metadata=True)
video_inputs, sample_fps = fetch_video(
video_dict, image_patch_size=16, return_video_sample_fps=True, return_video_metadata=True
)
frames, video_metadata = video_inputs
frames = frames.numpy()
return frames, video_metadata, sample_fps
Expand Down
39 changes: 18 additions & 21 deletions src/lmms_engine/datasets/processor/qwen3_vl_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def get_qwen_template_labels(
curr_timestamp = self.processor._calculate_timestamps(
metadata.frames_indices,
metadata.fps,
self.processor.video_processor.merge_size,
self.processor.video_processor.temporal_patch_size,
)
encode_id, used_video = self._expand_encode_id_video_tokens(
encode_id,
Expand Down Expand Up @@ -195,35 +195,32 @@ def _expand_encode_id_video_tokens(
prev = 0
merge_length = self.processor.video_processor.merge_size**2
for idx, pos in enumerate(video_pos):
# Before image pos, no expand
expanded_encode_id.extend(encode_id[prev:pos])
# Image pos, expand
# The original chat template produces: <|vision_start|> <|video_pad|> <|vision_end|>
# We replace this entire triplet with per-frame blocks.
# To match transformers Qwen3VLProcessor, each frame should be:
# <timestamp> <|vision_start|> <video_tokens> <|vision_end|>
# So we exclude the original <|vision_start|> (at pos-1) and <|vision_end|> (at pos+1).
expanded_encode_id.extend(encode_id[prev : pos - 1])

frame_seq_len = video_grid_thw[idx + start_from][1:].prod() // merge_length
for frame_idx in range(video_grid_thw[idx + start_from][0]):
curr_time = curr_timestamp[frame_idx]
timestamp_token = f"<{curr_time:.1f} seconds>"
timestamp_token_id = self.processor.tokenizer.encode(timestamp_token)
visual_tokens = [self.video_token_id] * frame_seq_len
# Three cases
# If first frame, the start token in being added to the expanded encode id already, no need to include
# If last frame, the end token will be added to the expanded encode id later, no need to include
# If middle frame, both start and end tokens need to be included
if frame_idx == 0:
curr_expand_video_ids = timestamp_token_id + visual_tokens + [self.processor.vision_end_token_id]
elif frame_idx == video_grid_thw[idx + start_from][0] - 1:
curr_expand_video_ids = [self.processor.vision_start_token_id] + timestamp_token_id + visual_tokens
else:
curr_expand_video_ids = (
[self.processor.vision_start_token_id]
+ timestamp_token_id
+ visual_tokens
+ [self.processor.vision_end_token_id]
)
# Each frame: <timestamp> <|vision_start|> <video_tokens> <|vision_end|>
curr_expand_video_ids = (
timestamp_token_id
+ [self.processor.vision_start_token_id]
+ visual_tokens
+ [self.processor.vision_end_token_id]
)
expanded_encode_id.extend(curr_expand_video_ids)
prev = pos + 1
# Skip past the original <|vision_end|> at pos+1
prev = pos + 2

if idx == len(video_pos) - 1:
# Last image pos, Add the rest to the end
# Last video pos, add the rest to the end
expanded_encode_id.extend(encode_id[prev:])

return expanded_encode_id, len(video_pos)
Loading