Batch Inference using Fast-dLLM #65

@Gaurav7888

Hi Authors,
How would you suggest doing batch inference with Fast-dLLM, in the same way that generate.py handles inference for a single prompt?

I have tried the code below. Please correct me if I am doing it wrong.

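import time

import torch
from torch.cuda import nvtx
from transformers import AutoTokenizer

# Assumed import paths: they resolve when the script is run from the Fast-dLLM llada/ directory
from model.modeling_llada import LLaDAModelLM
from generate import generate_with_dual_cache

device = 'cuda'

model = LLaDAModelLM.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)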

prompts_text = [
    "Krissa needs to order field trip shirts for her preschool students. 11 students need size extra-small. Twice as many students need size small as extra small. Four less than the number of size small students need size medium. Half as many students need size large as size medium. Six more students need size extra-large than large. Altogether, how many shirts did Krissa order?",
    "Jim decides to go to college to earn some more money. It takes him 4 years to finish and he gets $50,000 in loans per year. If he had a 25k a year job before college and his college degree tripled his income, how long would it take to earn the money equivalent to the loans and the money lost from not working while in school.",
    "Russell orders his favorite bagels online. Each pack of bagels costs $10.00 and has 9 bagels in the pack. If he orders 4 packs of bagels, he will receive a 10% discount. After ordering 4 bags, how much will each single bagel cost?",
    "Brendan has a bag of marbles with 10 inside. He tripped over a pebble while carrying it and dropped half of them. He scrambled to search for them but only came up with 3. When he went back home, he inspected the marbles further. One of them he picked up wasn't a marble, but actually a bead so he got rid of it. How many marbles did Brendan end up with?",
    "The great dragon, Perg, sat high atop mount Farbo, breathing fire upon anything within a distance of 1000 feet. Polly could throw the gold javelin, the only known weapon that could sleigh the dragon, for a distance of 400 feet, well within the reach of the dragon's flames. But when Polly held the sapphire gemstone, she could throw the javelin three times farther than when not holding the gemstone. If holding the gemstone, how far outside of the reach of the dragon's flames could Polly stand and still hit the dragon with the gold javelin?",
    "James is in charge of running messages from the office to each teacher's classroom. If he delivers 66 messages to Ms. Thompson and 1/3 as many messages to Mr. Yu, how many messages does he deliver total?",
    "Mariah's grandma was teaching her to knit. Mariah used 1/4 of a skein of yarn. Her grandma used 1/2 of a skein of yarn. There are 364 yards in a skein of yarn. How many yards of yarn did they use altogether?",
    "Tim gets a promotion that offers him a 5% raise on his $20000 a month salary. It also gives him a bonus worth half a month's salary. How much money will he make in a year?",
    "Gissela, Gordy, and Gary are truck drivers. Gissela has a truck large enough to haul 4,000 pounds of gravel. Gordy's truck can haul 800 pounds more than Gissela's truck. And when Gary brings his truck and joins Gissela and Gordy, the three trucks combined can haul a total of 11,600 pounds of gravel. How many pounds of gravel can Gary's truck carry?",
    "Marcell and Beatrice are having a contest to see who can eat the most fruit roll-ups, so they unroll as many as they can find. Unfortunately, someone makes a mistake and Beatrice's was two roll-ups wide and 24 rolls up long while Marcell's was 3 roll-ups wide and 14 roll-ups long. If they both ate their entire amount, how many did they eat on average?"
]

# Tokenize each prompt separately; left padding is applied further below
batch_size = len(prompts_text)
input_ids_list = []

for prompt in prompts_text:
    m = [{"role": "user", "content": prompt}]
    formatted = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
    ids = tokenizer(formatted)['input_ids']
    input_ids_list.append(torch.tensor(ids))

# Pad to max length (left padding)
max_len = max(len(ids) for ids in input_ids_list)
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

input_ids = torch.full((batch_size, max_len), pad_token_id, dtype=torch.long, device=device)

for i, ids in enumerate(input_ids_list):
    start_pos = max_len - len(ids)
    input_ids[i, start_pos:] = ids

print(f"\nBatch size: {batch_size}")
print(f"Prompt length (padded): {max_len}")
print(f"Generation length: 256")
print("="*60)

# Run batch generation
with torch.inference_mode():
    torch.cuda.synchronize()
    start_time = time.time()
    nvtx.range_push("BATCH_INFER")

    out, nfe = generate_with_dual_cache(
        model, input_ids,
        steps=256, gen_length=256, block_length=32,
        temperature=0., remasking='low_confidence'
    )

    torch.cuda.synchronize()
    nvtx.range_pop()
    end_time = time.time()

total_time = end_time - start_time
total_tokens = batch_size * 256
tokens_per_second_total = total_tokens / total_time
# Each sequence produces 256 tokens over the same wall-clock time
tokens_per_second_per_seq = 256 / total_time

print(f"\n=== Batch Generation Results ===")
print(f"Total time: {total_time:.2f}s")
print(f"NFE: {nfe}")
print(f"Total throughput: {tokens_per_second_total:.1f} tok/s (all {batch_size} sequences)")
print(f"Per-sequence: {tokens_per_second_per_seq:.1f} tok/s")
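
To make the padding and throughput bookkeeping in the script above easier to check in isolation, here is a minimal, framework-free sketch. `left_pad` and `throughput` are hypothetical helper names, not part of Fast-dLLM. The attention mask is returned only to make the padded positions explicit; the sketch does not assume that `generate_with_dual_cache` accepts or uses such a mask.

```python
from typing import List, Tuple


def left_pad(
    sequences: List[List[int]], pad_id: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pad token-id lists to a common length.

    Returns (padded_ids, attention_mask), where the mask is 0 on padding
    positions and 1 on real tokens.
    """
    if not sequences:
        return [], []
    max_len = max(len(seq) for seq in sequences)
    padded, mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        # Padding goes on the left so every prompt ends at the same index
        padded.append([pad_id] * n_pad + list(seq))
        mask.append([0] * n_pad + [1] * len(seq))
    return padded, mask


def throughput(
    batch_size: int, gen_length: int, seconds: float
) -> Tuple[float, float]:
    """Return (aggregate tok/s, per-sequence tok/s) for one batched call."""
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    per_seq = gen_length / seconds
    return batch_size * per_seq, per_seq


if __name__ == "__main__":
    ids, mask = left_pad([[5, 6, 7], [8]], pad_id=0)
    print(ids)   # [[5, 6, 7], [0, 0, 8]]
    print(mask)  # [[1, 1, 1], [0, 0, 1]]
    print(throughput(batch_size=10, gen_length=256, seconds=4.0))  # (640.0, 64.0)
```

The padded lists can be wrapped with `torch.tensor(..., dtype=torch.long, device=device)` to obtain the same `input_ids` layout the script builds by hand.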
