Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions dream/dream_generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import torch
from transformers import AutoTokenizer
from model.modeling_dream import DreamModel

if __name__ == "__main__":

    # Demo driver: load the Dream diffusion LM and run batched diffusion
    # generation over six sample prompts, printing each decoded answer.
    model_path = "Dream-org/Dream-v0-Instruct-7B"
    model = DreamModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True)

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = model.to("cuda").eval()

    question_1 = "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?"
    question_2 = 'Write a story that ends with "Finally, Joey and Rachel get married."'
    question_3 = "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?"
    question_4 = "Give me a short introduction to large language model?"
    question_5 = "Can you introduce something about Paris?"
    question_6 = "Write a code for quicksort. "

    # FIX: the original used "\\boxed\{\}", where "\{"/"\}" are invalid
    # escape sequences (SyntaxWarning on Python 3.12+) and yield the literal
    # "\boxed\{\}". The intended LaTeX literal is "\boxed{}".
    messages = [
        [{"role": "user", "content": "Answer the question step by step and put the answer in \\boxed{}: " + question_1}],
        [{"role": "user", "content": question_2}],
        [{"role": "user", "content": "Answer the question step by step and put the answer in \\boxed{}: " + question_3}],
        [{"role": "user", "content": question_4}],
        [{"role": "user", "content": question_5}],
        [{"role": "user", "content": question_6}]
    ]

    prompts = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    # Left padding keeps every prompt right-aligned so generated tokens
    # start at the same column for all batch rows.
    prompt_ids = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
    input_ids = prompt_ids.input_ids.to(device="cuda")
    attention_mask = prompt_ids.attention_mask.to(device="cuda")

    output = model.diffusion_generate(
        inputs=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=128,
        output_history=True,
        return_dict_in_generate=True,
        steps=128,            # one denoising step per generated token
        temperature=0.0,      # greedy decoding; top_p has no effect at T=0
        top_p=0.95,
        alg="entropy",        # unmask lowest-entropy positions first
        threshold=0.9,
        block_length=32,
        dual_cache=True,
    )

    # The shared padded prompt length: with left padding every row of
    # input_ids has the same width, so slicing at shape[1] strips exactly
    # the prompt (plus its padding) from each generated sequence.
    prompt_len = input_ids.shape[1]
    for b in range(len(messages)):
        print()
        print(f"----Question {b+1}: {messages[b][0]['content']}")
        sequence = output.sequences[b]
        # Cut the completion at the first end-of-text marker.
        print(tokenizer.decode(sequence[prompt_len:]).split('<|endoftext|>')[0])


9 changes: 8 additions & 1 deletion dream/model/generation_utils_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,14 @@ def _sample(
# Prepare attention mask for cached generation
if attention_mask != "full":
# Adjust attention mask for current position
current_attention_mask = attention_mask[:, :, :, current_block_start:]
if dual_cache:
# In dual_cache mode: query is block tokens, key is full sequence
# attention_mask shape: [B, 1, N, N] -> need [B, 1, block_length, N]
current_attention_mask = attention_mask[:, :, current_block_start:current_block_end, :]
else:
# In non-dual-cache mode: query is remaining tokens, key is full sequence
# attention_mask shape: [B, 1, N, N] -> need [B, 1, remaining_length, N]
current_attention_mask = attention_mask[:, :, current_block_start:, :]
else:
current_attention_mask = attention_mask

Expand Down
15 changes: 11 additions & 4 deletions dream/model/modeling_dream.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
)
from transformers import PretrainedConfig
from .configuration_dream import DreamConfig
from .generation_utils import DreamGenerationMixin, DreamGenerationConfig
from .generation_utils_block import DreamGenerationMixin, DreamGenerationConfig

if is_flash_attn_2_available():
from transformers.modeling_flash_attention_utils import _flash_attention_forward
Expand Down Expand Up @@ -402,10 +402,17 @@ def forward(
if past_key_value is not None:
if dual_cache:
past_key, past_value = past_key_value
replace_indices = replace_position.nonzero(as_tuple=True)[1]
past_key[:, replace_indices] = key_states
replace_indices = replace_position.nonzero(as_tuple=True)[1]

# Handle batched replace_position correctly
B = replace_position.shape[0]
for batch_idx in range(B):
batch_replace_indices = replace_position[batch_idx].nonzero(as_tuple=True)[0]
if len(batch_replace_indices) > 0:
past_key[batch_idx, batch_replace_indices] = key_states[batch_idx, :len(batch_replace_indices)]
past_value[batch_idx, batch_replace_indices] = value_states[batch_idx, :len(batch_replace_indices)]

key_states = past_key
past_value[:, replace_indices] = value_states
value_states = past_value
else:
past_key, past_value = past_key_value
Expand Down