Bug Report: sample_with_top_p Batch Size Mismatch in Fast-dLLM v2
Summary
The sample_with_top_p method in the Fast-dLLM v2 model contains a critical bug that causes an IndexError when using batch sizes greater than 1 with temperature > 0. The function only samples from the first batch element, returning a tensor with batch_size=1 regardless of the actual input batch size.
Affected Component
File: modeling.py (HuggingFace cached model)
Location: /home/xiaoyou/.cache/huggingface/modules/transformers_modules/Efficient_hyphen_Large_hyphen_Model/Fast_dLLM_v2_7B/200e3eff9223d719e97e561c2291566d9b1cc28d/modeling.py
Method: sample_with_top_p (lines 753-785)
Model: Efficient-Large-Model/Fast_dLLM_v2_7B
Error Message
```
IndexError: The shape of the mask [2, 32] at index 0 does not match the shape of the indexed tensor [1, 32] at index 0
```
Root Cause Analysis
Buggy Code (Line 783)
```python
def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
    # ... (lines 753-781)
    p_1t = normalized_probs
    # BUG: Only samples from p_1t[0], ignoring other batch elements
    x_1 = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)  # <-- BUG HERE
    return x_1, p_1t
```
Problem Explanation
The call `torch.multinomial(p_1t[0], num_samples=1)` samples only from `p_1t[0]` (the first element in the batch dimension); the result is then reshaped to `[1, seq_len]`. This causes:
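The behavior follows directly from `torch.multinomial`'s 2D semantics: given a `[rows, categories]` input, it draws `num_samples` per row. Because `p_1t[0]` is `[seq_len, vocab_size]`, the call produces one token per position for the first batch element only. A minimal sketch of both the buggy and the corrected pattern (shapes and variable names are illustrative, not taken from the model):

```python
import torch

batch_size, seq_len, vocab_size = 2, 32, 100
p_1t = torch.softmax(torch.randn(batch_size, seq_len, vocab_size), dim=-1)

# Buggy pattern: indexing [0] drops the batch dimension before sampling,
# so multinomial sees only the first batch element's distributions
buggy = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)
print(buggy.shape)  # torch.Size([1, 32])

# Correct pattern: flatten batch and sequence dims so every row is sampled
flat = p_1t.view(-1, vocab_size)  # [batch_size * seq_len, vocab_size]
fixed = torch.multinomial(flat, num_samples=1).view(batch_size, seq_len)
print(fixed.shape)  # torch.Size([2, 32])
```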
| Tensor | Expected Shape | Actual Shape |
|---|---|---|
| `p_1t` | `[batch_size, seq_len, vocab_size]` | `[batch_size, seq_len, vocab_size]` ✓ |
| `x_1` | `[batch_size, seq_len]` | `[1, seq_len]` ✗ |
Downstream Failure
In `generation_functions.py` (line 292), the code attempts:
```python
x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
```
Where:
- `x_t[:, start:end]` has shape `[batch_size, block_size]` (e.g., `[2, 32]`)
- `unmask_idx` has shape `[batch_size, block_size]` (e.g., `[2, 32]`)
- `x_1` has shape `[1, block_size]` (e.g., `[1, 32]`) - WRONG!
PyTorch raises IndexError because the boolean mask shape [2, 32] doesn't match the tensor shape [1, 32].
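PyTorch's boolean-mask indexing requires the mask's shape to match the leading dimensions of the indexed tensor, which is why the mismatch surfaces at this assignment rather than at the sampling step. A small illustration with toy tensors (not the model's code):

```python
import torch

x_1_buggy = torch.zeros(1, 32, dtype=torch.long)  # shape returned by the bug
x_1_fixed = torch.zeros(2, 32, dtype=torch.long)  # shape the caller expects
unmask_idx = torch.ones(2, 32, dtype=torch.bool)  # mask over the full batch

# Matching shapes: selects all 2 * 32 = 64 masked positions
print(x_1_fixed[unmask_idx].shape)  # torch.Size([64])

# Mismatched shapes: raises the reported IndexError
try:
    x_1_buggy[unmask_idx]
except IndexError as e:
    print(e)
```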
Reproduction Steps
Minimal Reproduction
```python
import torch
import torch.nn.functional as F

def sample_with_top_p_buggy(logits, top_p=0.95, temperature=1.0):
    """Buggy version from the model"""
    scaled_logits = logits / temperature
    probs = F.softmax(scaled_logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    indices_to_remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )
    probs[indices_to_remove] = 0
    probs_sum = torch.sum(probs, dim=-1, keepdim=True)
    normalized_probs = probs / probs_sum
    p_1t = normalized_probs
    # BUG: Only samples from first batch element
    x_1 = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)
    return x_1, p_1t

# Test
batch_size, seq_len, vocab_size = 2, 32, 100
logits = torch.randn(batch_size, seq_len, vocab_size)
x_1, p_1t = sample_with_top_p_buggy(logits, top_p=0.9, temperature=1.0)
print(f"p_1t shape: {p_1t.shape}")  # [2, 32, 100] ✓
print(f"x_1 shape: {x_1.shape}")    # [1, 32] ✗ (should be [2, 32])

# This causes the IndexError
unmask_idx = torch.ones(batch_size, seq_len, dtype=torch.bool)
x_1[unmask_idx]  # IndexError!
```
Full Reproduction with Model
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import types

from generation_functions import Fast_dLLM_QwenForCausalLM

model_name = "Efficient-Large-Model/Fast_dLLM_v2_7B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
device = "cuda:0"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map=device,
    trust_remote_code=True,
)
model.batch_sample = types.MethodType(Fast_dLLM_QwenForCausalLM.batch_sample, model)

# Batch of 2 inputs triggers the bug
input = ["what is the meaning of life", "what is the meaning of brushing my teeth"]
tokenized = tokenizer(input, padding=False)
seq_len = torch.tensor([len(ids) for ids in tokenized.input_ids], device=device)
min_len = seq_len.min().item()
input_ids = tokenizer(input, return_tensors="pt", padding=True).input_ids.to(device)

# This will raise IndexError when temperature > 0
finished_samples, steps_per_sample = model.batch_sample(
    input_ids=input_ids,
    tokenizer=tokenizer,
    block_size=128,
    max_new_tokens=512,
    small_block_size=32,
    min_len=min_len,
    seq_len=seq_len,
    mask_id=151665,
    threshold=0.9,
    stop_token=151645,
    use_block_cache=False,
    top_p=0.9,
    temperature=1.0,  # BUG: temperature > 0 triggers the issue
)
```
Proposed Fix
Option 1: Fix torch.multinomial to Handle Full Batch
```python
def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
    if temperature > 0:
        scaled_logits = logits / temperature
    else:
        p_1t = torch.softmax(logits, dim=-1)
        x_1 = p_1t.argmax(dim=-1)
        return x_1, p_1t
    probs = F.softmax(scaled_logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    indices_to_remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )
    probs[indices_to_remove] = 0
    probs_sum = torch.sum(probs, dim=-1, keepdim=True)
    normalized_probs = probs / probs_sum
    p_1t = normalized_probs
    # FIXED: Sample from all batch elements
    batch_size, seq_len, vocab_size = p_1t.shape
    p_1t_flat = p_1t.view(-1, vocab_size)  # [batch_size * seq_len, vocab_size]
    x_1_flat = torch.multinomial(p_1t_flat, num_samples=1).squeeze(-1)  # [batch_size * seq_len]
    x_1 = x_1_flat.view(batch_size, seq_len)  # [batch_size, seq_len]
    return x_1, p_1t
```
Option 2: Alternative Fix Using Loop
```python
# FIXED: Sample from each batch element
x_1 = torch.stack([
    torch.multinomial(p_1t[i], num_samples=1).squeeze(-1)
    for i in range(p_1t.shape[0])
], dim=0)
```
Workaround
Until the bug is fixed upstream, use `temperature=0.0` (greedy decoding):
```python
finished_samples, steps_per_sample = model.batch_sample(
    # ... other params ...
    temperature=0.0,  # WORKAROUND: Use greedy decoding to avoid the bug
)
```
When `temperature <= 0`, the function takes a different code path that correctly handles batching:
```python
if temperature <= 0:
    p_1t = torch.softmax(logits, dim=-1)
    x_1 = p_1t.argmax(dim=-1)  # This correctly returns [batch_size, seq_len]
    return x_1, p_1t
```
Impact
- Severity: High - prevents batch inference with `temperature > 0`
- Affected Users: Anyone using `batch_sample` with `batch_size > 1` and `temperature > 0`
- Functional Impact: Complete failure of batch generation
Environment
- Model: `Efficient-Large-Model/Fast_dLLM_v2_7B`
- Transformers Version: (check with `transformers.__version__`)
- PyTorch Version: (check with `torch.__version__`)
- Python Version: 3.10
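To fill in the version fields above, a quick snippet (`transformers` is imported defensively in case it is not installed in the reporting environment):

```python
import sys
import torch

print(f"Python: {sys.version.split()[0]}")
print(f"PyTorch: {torch.__version__}")
try:
    import transformers
    print(f"Transformers: {transformers.__version__}")
except ImportError:
    print("Transformers: not installed")
```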
References
- Model HuggingFace page: https://huggingface.co/Efficient-Large-Model/Fast_dLLM_v2_7B
- Buggy code location: `modeling.py:783`
- Downstream failure: `generation_functions.py:292`
Report Date: 2026-01-15
Reported By: Investigation via code analysis and reproduction testing