diff --git a/dream/dream_generate.py b/dream/dream_generate.py
new file mode 100644
index 0000000..0722809
--- /dev/null
+++ b/dream/dream_generate.py
@@ -0,0 +1,57 @@
+import torch
+from transformers import AutoTokenizer
+from model.modeling_dream import DreamModel
+
+if __name__ == "__main__":
+
+    model_path = "Dream-org/Dream-v0-Instruct-7B"
+    model = DreamModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    model = model.to("cuda").eval()
+
+    question_1 = "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?"
+    question_2 = 'Write a story that ends with "Finally, Joey and Rachel get married."'
+    question_3 = "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?"
+    question_4 = "Give me a short introduction to large language model?"
+    question_5 = "Can you introduce something about Paris?"
+    question_6 = "Write a code for quicksort."
+
+    messages = [
+        [{"role": "user", "content": "Answer the question step by step and put the answer in \\boxed{}: " + question_1}],
+        [{"role": "user", "content": question_2}],
+        [{"role": "user", "content": "Answer the question step by step and put the answer in \\boxed{}: " + question_3}],
+        [{"role": "user", "content": question_4}],
+        [{"role": "user", "content": question_5}],
+        [{"role": "user", "content": question_6}]
+    ]
+
+    prompts = tokenizer.apply_chat_template(
+        messages, add_generation_prompt=True, tokenize=False
+    )
+    prompt_ids = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
+    input_ids = prompt_ids.input_ids.to(device="cuda")
+    attention_mask = prompt_ids.attention_mask.to(device="cuda")
+
+    output = model.diffusion_generate(
+        inputs=input_ids,
+        attention_mask=attention_mask,
+        max_new_tokens=128,
+        output_history=True,
+        return_dict_in_generate=True,
+        steps=128,
+        temperature=0.0,
+        top_p=0.95,
+        alg="entropy",
+        threshold=0.9,
+        block_length=32,
+        dual_cache=True,
+    )
+
+    for b in range(len(messages)):
+        print()
+        print(f"----Question {b+1}: {messages[b][0]['content']}")
+        sequence = output.sequences[b]
+        print(tokenizer.decode(sequence[len(input_ids[0]):]).split('<|endoftext|>')[0])
+
+
\ No newline at end of file
diff --git a/dream/model/generation_utils_block.py b/dream/model/generation_utils_block.py
index d8ca61f..1804d7b 100644
--- a/dream/model/generation_utils_block.py
+++ b/dream/model/generation_utils_block.py
@@ -478,7 +478,14 @@ def _sample(
         # Prepare attention mask for cached generation
         if attention_mask != "full":
             # Adjust attention mask for current position
-            current_attention_mask = attention_mask[:, :, :, current_block_start:]
+            if dual_cache:
+                # In dual_cache mode: query is block tokens, key is full sequence
+                # attention_mask shape: [B, 1, N, N] -> need [B, 1, block_length, N]
+                current_attention_mask = attention_mask[:, :, current_block_start:current_block_end, :]
+            else:
+                # In non-dual-cache mode: query is remaining tokens, key is full sequence
+                # attention_mask shape: [B, 1, N, N] -> need [B, 1, remaining_length, N]
+                current_attention_mask = attention_mask[:, :, current_block_start:, :]
         else:
             current_attention_mask = attention_mask
diff --git a/dream/model/modeling_dream.py b/dream/model/modeling_dream.py
index 244f4ce..30fe5e9 100644
--- a/dream/model/modeling_dream.py
+++ b/dream/model/modeling_dream.py
@@ -42,7 +42,7 @@
 )
 from transformers import PretrainedConfig
 from .configuration_dream import DreamConfig
-from .generation_utils import DreamGenerationMixin, DreamGenerationConfig
+from .generation_utils_block import DreamGenerationMixin, DreamGenerationConfig
 
 if is_flash_attn_2_available():
     from transformers.modeling_flash_attention_utils import _flash_attention_forward
@@ -402,10 +402,15 @@ def forward(
         if past_key_value is not None:
             if dual_cache:
                 past_key, past_value = past_key_value
-                replace_indices = replace_position.nonzero(as_tuple=True)[1]
-                past_key[:, replace_indices] = key_states
+                # Handle batched replace_position correctly
+                B = replace_position.shape[0]
+                for batch_idx in range(B):
+                    batch_replace_indices = replace_position[batch_idx].nonzero(as_tuple=True)[0]
+                    if len(batch_replace_indices) > 0:
+                        past_key[batch_idx, batch_replace_indices] = key_states[batch_idx, :len(batch_replace_indices)]
+                        past_value[batch_idx, batch_replace_indices] = value_states[batch_idx, :len(batch_replace_indices)]
+
                 key_states = past_key
-                past_value[:, replace_indices] = value_states
                 value_states = past_value
             else:
                 past_key, past_value = past_key_value