
Why not use input_ids to compute_target_p? #413

@jameswu2014

Description


In the _compute_target_p function in eagle3.py, target_mask is obtained by indexing t2d with the argmax of the target logits.
The target model's outputs involve some numerical-precision uncertainty and sampling randomness, and the argmax token is not always the token that actually comes next in the data. Using the argmax token directly to decide target_mask might therefore introduce bias. Why not build target_mask from input_ids (i.e., the labels) instead? The current code is as follows:

import torch
import torch.nn as nn

@torch.compile(dynamic=None)
def _compute_target_p(target, t2d, loss_mask):
    target_head = target
    # A position is kept only if the target's argmax token is in the draft vocab
    target_max_token = target_head.argmax(-1)
    target_mask = t2d[target_max_token]
    target_mask = target_mask[..., None].int()
    position_mask = target_mask * loss_mask
    # Restrict the target logits to draft-vocab tokens and normalize them
    target_head = target_head[..., t2d]
    target_head = target_head.float()
    target_p = nn.Softmax(dim=2)(target_head)
    target_p = target_p.detach()
    return target_p, position_mask
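To make the concern concrete, here is a toy sketch in plain Python (no torch, no torch.compile) that mimics both ways of building the position mask for a single sequence. The 5-token vocabulary, the t2d table, and the logits are made up for illustration, and argmax_mask and label_mask are hypothetical helpers, not functions from eagle3.py.

```python
# Hypothetical toy setup: a 5-token target vocabulary in which tokens
# 0, 2 and 3 also exist in the draft vocabulary (t2d[i] is True).
t2d = [True, False, True, True, False]


def argmax(row):
    """Index of the largest logit; the first index wins on ties."""
    return max(range(len(row)), key=lambda i: row[i])


def argmax_mask(logits, t2d):
    """Current rule: keep position i if argmax(logits[i]) is in the draft vocab."""
    return [t2d[argmax(row)] for row in logits]


def label_mask(input_ids, t2d):
    """Proposed rule: keep position i if the true next token input_ids[i + 1]
    is in the draft vocab. The last position has no next token."""
    labels = input_ids[1:] + [0]  # shift left by one, pad with a dummy id
    mask = [t2d[tok] for tok in labels]
    mask[-1] = False
    return mask


input_ids = [2, 3, 0, 1]
# logits[i] scores the token at position i + 1.
logits = [
    [0.1, 0.2, 0.3, 2.0, 0.0],   # argmax 3 (draft), label 3 (draft)
    [1.00, 0.0, 0.0, 0.0, 1.01], # argmax 4 (not draft), label 0 (draft)
    [0.0, 0.9, 1.0, 0.0, 0.0],   # argmax 2 (draft), label 1 (not draft)
    [0.0, 0.0, 0.0, 0.0, 3.0],   # last position: no label
]

print(argmax_mask(logits, t2d))  # [True, False, True, False]
print(label_mask(input_ids, t2d))  # [True, True, False, False]
```

In this toy case, position 1 is dropped by the argmax rule even though its true next token (0) is in the draft vocabulary, and position 2 is kept even though its true next token (1) is not.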

A new version could look like this:

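import torch
import torch.nn as nn

@torch.compile(dynamic=None)
def _compute_target_p(target, t2d, loss_mask, input_ids):
    # The logits at position i score token i + 1, so shift input_ids left
    # by one and pad the last position with a dummy id.
    labels = torch.cat([
        input_ids[:, 1:],
        torch.zeros_like(input_ids[:, :1]),
    ], dim=1)
    # Keep a position only if its ground-truth next token is in the draft vocab
    target_mask = t2d[labels]
    target_mask[:, -1] = False  # the last position has no next token
    target_mask = target_mask[..., None].int()
    position_mask = target_mask * loss_mask

    target_head = target
    target_head = target_head[..., t2d]
    target_head = target_head.float()
    target_p = nn.Softmax(dim=2)(target_head)
    target_p = target_p.detach()
    return target_p, position_mask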
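Whether this change matters in practice could be checked by logging how often the two masks disagree on positions that actually contribute to the loss. The helper below is a hypothetical sketch, assuming both position masks and loss_mask have already been flattened to Python lists (for example with .flatten().tolist() on the tensors); mask_disagreement is not part of eagle3.py.

```python
def mask_disagreement(mask_a, mask_b, loss_mask):
    """Fraction of loss-masked positions where two position masks differ.

    All three arguments are equal-length sequences of bools or 0/1 ints.
    Returns 0.0 when no position contributes to the loss.
    """
    positions = [i for i, keep in enumerate(loss_mask) if keep]
    if not positions:
        return 0.0
    differing = sum(1 for i in positions if bool(mask_a[i]) != bool(mask_b[i]))
    return differing / len(positions)


# Example: argmax-based vs label-based masks over four positions,
# where the last position is excluded by loss_mask.
rate = mask_disagreement([1, 0, 1, 0], [1, 1, 0, 0], [1, 1, 1, 0])
print(rate)
```

A rate close to zero would suggest the argmax-based mask is a harmless approximation, while a large rate would support building the mask from the labels.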
