In the `_compute_target_p` function in the eagle3.py file, the `target_mask` is looked up in `t2d` using the argmax of the target logits.
Since sampling from the target model is subject to numerical-precision noise and randomness, deriving the `target_mask` from the argmax token may introduce bias. Why not generate the `target_mask` from the shifted `input_ids` (i.e., the labels) instead? The current code is as follows:
```python
@torch.compile(dynamic=None)
def _compute_target_p(target, t2d, loss_mask):
    target_head = target
    target_max_token = target_head.argmax(-1)
    target_mask = t2d[target_max_token]
    target_mask = target_mask[..., None].int()
    position_mask = target_mask * loss_mask
    target_head = target_head[..., t2d]
    target_head = target_head.float()
    target_p = nn.Softmax(dim=2)(target_head)
    target_p = target_p.detach()
    return target_p, position_mask
```
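To make the current behaviour concrete, here is a toy sketch with hypothetical shapes (batch 1, sequence length 4, target vocab 6, draft vocab 3): a position is trained only if the target model's argmax token happens to be representable in the draft vocabulary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
target = torch.randn(1, 4, 6)       # hypothetical target-model logits
# t2d: which target-vocab tokens exist in the draft vocab (hypothetical)
t2d = torch.tensor([True, False, True, False, True, False])
loss_mask = torch.ones(1, 4, 1, dtype=torch.int)

# Current behaviour: mask positions by the target model's argmax token.
target_max_token = target.argmax(-1)            # (1, 4)
target_mask = t2d[target_max_token]             # (1, 4) bool
position_mask = target_mask[..., None].int() * loss_mask  # (1, 4, 1)

# Probabilities restricted to the draft vocabulary.
target_p = nn.Softmax(dim=2)(target[..., t2d].float()).detach()  # (1, 4, 3)
print(position_mask.shape, target_p.shape)
```

If two tokens have near-equal logits, the argmax (and hence the mask) can flip with tiny numerical perturbations, which is the bias concern raised above.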
The proposed new version:
```python
@torch.compile(dynamic=None)
def _compute_target_p(target, t2d, loss_mask, input_ids):
    labels = torch.cat([
        input_ids[:, 1:],
        torch.zeros_like(input_ids[:, :1]),
    ], dim=1)
    target_mask = t2d[labels]
    target_mask[:, -1] = False
    target_mask = target_mask[..., None].int()
    position_mask = target_mask * loss_mask
    target_head = target
    target_head = target_head[..., t2d]
    target_head = target_head.float()
    target_p = nn.Softmax(dim=2)(target_head)
    target_p = target_p.detach()
    return target_p, position_mask
```
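The label-shift construction can be checked in isolation with a toy sketch (hypothetical `input_ids` and `t2d`): position `i` is masked by the ground-truth next token `input_ids[i+1]` rather than by the target model's argmax, and the final position, which has no real label, is masked out.

```python
import torch

# Hypothetical toy inputs: draft vocab covers target tokens {0, 2, 4}.
input_ids = torch.tensor([[2, 5, 0, 4]])
t2d = torch.tensor([True, False, True, False, True, False])

# Shift input_ids left by one so position i is judged by its ground-truth
# next token; pad the last slot with a dummy token.
labels = torch.cat([input_ids[:, 1:], torch.zeros_like(input_ids[:, :1])], dim=1)
target_mask = t2d[labels]
target_mask[:, -1] = False   # last position has no real label
print(target_mask)
# tensor([[False,  True,  True, False]])
```

Unlike the argmax version, this mask is fully deterministic given the training data, so it cannot flip due to numerical noise in the target logits.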