Hi maintainers,

Thank you for the great work. Here is the snippet I'm referring to:
```python
# Process each block
for num_block in range(num_blocks):
    current_block_start = input_ids.shape[1] + num_block * block_length
    current_block_end = current_block_start + block_length

    # Update cache
    model_output = self(x, attention_mask, tok_idx, use_cache=True)
    past_key_values = model_output.past_key_values
    logits = model_output.logits
    logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

    confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k)
    x[:, current_block_start] = x0[:, current_block_start]
```
In Dream's cache implementation, why does each block update directly unmask the token at the block-start position (`x[:, current_block_start] = x0[:, current_block_start]`), rather than choosing which position to unmask with a specific sampling strategy?
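To clarify what I mean by a "specific sampling strategy": I would have expected something like confidence-based unmasking, where the still-masked position with the highest sampling confidence is revealed first, rather than always the block-start position. A minimal sketch of that alternative is below; the helper name `unmask_one_token` and the `mask_id` convention are my own illustration, not taken from Dream's code:

```python
import torch

def unmask_one_token(x, x0, confidence, block_start, block_end, mask_id):
    """Within [block_start, block_end), unmask the still-masked position
    with the highest confidence, instead of always the block start."""
    block_conf = confidence[:, block_start:block_end].clone()
    # Exclude positions that are already unmasked from the argmax.
    block_conf[x[:, block_start:block_end] != mask_id] = -float("inf")
    pick = block_conf.argmax(dim=-1) + block_start  # one position per batch row
    rows = torch.arange(x.shape[0])
    x[rows, pick] = x0[rows, pick]
    return x
```

Is the direct block-start assignment a deliberate design choice (e.g. required for the cache update to stay consistent), or just a simplification?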