batch size > 1 and temperature != 0.0 not supported by the generate function file #64

@Laurence-Wu

Description

Bug Report: sample_with_top_p Batch Size Mismatch in Fast-dLLM v2

Summary

The sample_with_top_p method in the Fast-dLLM v2 model contains a critical bug that causes an IndexError when using batch sizes greater than 1 with temperature > 0. The function only samples from the first batch element, returning a tensor with batch_size=1 regardless of the actual input batch size.

Affected Component

File: modeling.py (HuggingFace cached model)
Location: /home/xiaoyou/.cache/huggingface/modules/transformers_modules/Efficient_hyphen_Large_hyphen_Model/Fast_dLLM_v2_7B/200e3eff9223d719e97e561c2291566d9b1cc28d/modeling.py
Method: sample_with_top_p (lines 753-785)
Model: Efficient-Large-Model/Fast_dLLM_v2_7B

Error Message

IndexError: The shape of the mask [2, 32] at index 0 does not match the shape of the indexed tensor [1, 32] at index 0

Root Cause Analysis

Buggy Code (Line 783)

def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
    # ... (lines 753-781)

    p_1t = normalized_probs
    # BUG: Only samples from p_1t[0], ignoring other batch elements
    x_1 = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)  # <-- BUG HERE

    return x_1, p_1t

Problem Explanation

The call torch.multinomial(p_1t[0], num_samples=1) samples only from p_1t[0] (the first element along the batch dimension), producing a [seq_len, 1] tensor; the subsequent .unsqueeze(0).squeeze(-1) reshapes it to [1, seq_len] instead of [batch_size, seq_len]. This causes:

Tensor   Expected Shape                       Actual Shape
p_1t     [batch_size, seq_len, vocab_size]    [batch_size, seq_len, vocab_size]
x_1      [batch_size, seq_len]                [1, seq_len]
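
The [0] indexing likely exists because torch.multinomial only accepts 1-D or 2-D probability tensors, so the 3-D p_1t cannot be passed in directly. A minimal sketch (shapes are illustrative):

```python
import torch

# torch.multinomial rejects 3-D input, which is likely why the original
# code indexed p_1t[0] -- dropping the batch dimension to get a 2-D tensor.
p_1t = torch.softmax(torch.randn(2, 32, 100), dim=-1)  # [batch, seq, vocab]

rejected = False
try:
    torch.multinomial(p_1t, num_samples=1)  # 3-D input: raises RuntimeError
except RuntimeError:
    rejected = True

# p_1t[0] is 2-D ([32, 100]), so sampling works -- but only for batch element 0
sample = torch.multinomial(p_1t[0], num_samples=1)  # shape [32, 1]
```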

Downstream Failure

In generation_functions.py (line 292), the code attempts:

x_t[:, start:end][unmask_idx] = x_1[unmask_idx]

Where:

  • x_t[:, start:end] has shape [batch_size, block_size] (e.g., [2, 32])
  • unmask_idx has shape [batch_size, block_size] (e.g., [2, 32])
  • x_1 has shape [1, block_size] (e.g., [1, 32]) - WRONG!

PyTorch raises IndexError because the boolean mask shape [2, 32] doesn't match the tensor shape [1, 32].
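
The failure can be reproduced in isolation: boolean-mask indexing requires the mask shape to match the indexed tensor's shape. A minimal sketch with the shapes from above:

```python
import torch

# x_1 comes back from the sampler as [1, 32] instead of [2, 32]
x_1 = torch.zeros(1, 32)
# unmask_idx covers the full batch
unmask_idx = torch.ones(2, 32, dtype=torch.bool)

failed = False
try:
    x_1[unmask_idx]  # boolean mask [2, 32] vs indexed tensor [1, 32]
except IndexError:
    failed = True
```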

Reproduction Steps

Minimal Reproduction

import torch
import torch.nn.functional as F

def sample_with_top_p_buggy(logits, top_p=0.95, temperature=1.0):
    """Buggy version from the model"""
    scaled_logits = logits / temperature
    probs = F.softmax(scaled_logits, dim=-1)

    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    indices_to_remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )
    probs[indices_to_remove] = 0

    probs_sum = torch.sum(probs, dim=-1, keepdim=True)
    normalized_probs = probs / probs_sum

    p_1t = normalized_probs
    # BUG: Only samples from first batch element
    x_1 = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)

    return x_1, p_1t

# Test
batch_size, seq_len, vocab_size = 2, 32, 100
logits = torch.randn(batch_size, seq_len, vocab_size)

x_1, p_1t = sample_with_top_p_buggy(logits, top_p=0.9, temperature=1.0)

print(f"p_1t shape: {p_1t.shape}")      # [2, 32, 100] ✓
print(f"x_1 shape: {x_1.shape}")        # [1, 32] ✗ (should be [2, 32])

# This causes the IndexError
unmask_idx = torch.ones(batch_size, seq_len, dtype=torch.bool)
x_1[unmask_idx]  # IndexError!

Full Reproduction with Model

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import types
from generation_functions import Fast_dLLM_QwenForCausalLM

model_name = "Efficient-Large-Model/Fast_dLLM_v2_7B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
device = "cuda:0"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map=device,
    trust_remote_code=True
)

model.batch_sample = types.MethodType(Fast_dLLM_QwenForCausalLM.batch_sample, model)

# Batch of 2 inputs triggers the bug
input = ["what is the meaning of life", "what is the meaning of brushing my teeth"]
tokenized = tokenizer(input, padding=False)
seq_len = torch.tensor([len(ids) for ids in tokenized.input_ids], device=device)
min_len = seq_len.min().item()
input_ids = tokenizer(input, return_tensors="pt", padding=True).input_ids.to(device)

# This will raise IndexError when temperature > 0
finished_samples, steps_per_sample = model.batch_sample(
    input_ids=input_ids,
    tokenizer=tokenizer,
    block_size=128,
    max_new_tokens=512,
    small_block_size=32,
    min_len=min_len,
    seq_len=seq_len,
    mask_id=151665,
    threshold=0.9,
    stop_token=151645,
    use_block_cache=False,
    top_p=0.9,
    temperature=1.0,  # BUG: temperature > 0 triggers the issue
)

Proposed Fix

Option 1: Fix torch.multinomial to Handle Full Batch

def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
    if temperature > 0:
        scaled_logits = logits / temperature
    else:
        p_1t = torch.softmax(logits, dim=-1)
        x_1 = p_1t.argmax(dim=-1)
        return x_1, p_1t

    probs = F.softmax(scaled_logits, dim=-1)

    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    indices_to_remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )

    probs[indices_to_remove] = 0

    probs_sum = torch.sum(probs, dim=-1, keepdim=True)
    normalized_probs = probs / probs_sum

    p_1t = normalized_probs

    # FIXED: Sample from all batch elements
    batch_size, seq_len, vocab_size = p_1t.shape
    p_1t_flat = p_1t.view(-1, vocab_size)  # [batch_size * seq_len, vocab_size]
    x_1_flat = torch.multinomial(p_1t_flat, num_samples=1).squeeze(-1)  # [batch_size * seq_len]
    x_1 = x_1_flat.view(batch_size, seq_len)  # [batch_size, seq_len]

    return x_1, p_1t
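
A standalone check of the flatten/sample/reshape pattern in Option 1, using random logits with the shapes from the repro above:

```python
import torch

batch_size, seq_len, vocab_size = 2, 32, 100
p_1t = torch.softmax(torch.randn(batch_size, seq_len, vocab_size), dim=-1)

# Flatten to 2-D so multinomial samples every (batch, position) row
# independently, then restore the batch dimension.
p_flat = p_1t.view(-1, vocab_size)                      # [64, 100]
x_1 = torch.multinomial(p_flat, num_samples=1).view(batch_size, seq_len)

# [2, 32] -- now matches the boolean mask shape downstream
print(x_1.shape)
```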

Option 2: Alternative Fix Using Loop

    # FIXED: Sample from each batch element
    x_1 = torch.stack([
        torch.multinomial(p_1t[i], num_samples=1).squeeze(-1)
        for i in range(p_1t.shape[0])
    ], dim=0)

Workaround

Until the bug is fixed upstream, use temperature=0.0 (greedy decoding):

finished_samples, steps_per_sample = model.batch_sample(
    # ... other params ...
    temperature=0.0,  # WORKAROUND: Use greedy decoding to avoid the bug
)

When temperature <= 0, the function takes a different code path that correctly handles batching:

if temperature <= 0:
    p_1t = torch.softmax(logits, dim=-1)
    x_1 = p_1t.argmax(dim=-1)  # This correctly returns [batch_size, seq_len]
    return x_1, p_1t
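
As a sanity check, the greedy path preserves the batch dimension because argmax reduces only the vocabulary axis:

```python
import torch

logits = torch.randn(2, 32, 100)
p_1t = torch.softmax(logits, dim=-1)
x_1 = p_1t.argmax(dim=-1)  # reduces the vocab dim only -> [2, 32]
```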

Impact

  • Severity: High - Prevents batch inference with temperature > 0
  • Affected Users: Anyone using batch_sample with batch_size > 1 and temperature > 0
  • Functional Impact: Complete failure of batch generation

Environment

  • Model: Efficient-Large-Model/Fast_dLLM_v2_7B
  • Transformers Version: (check with transformers.__version__)
  • PyTorch Version: (check with torch.__version__)
  • Python Version: 3.10

Report Date: 2026-01-15
Reported By: Investigation via code analysis and reproduction testing
