Hotfix/fused ce triton #409
base: dev
Changes from all commits
Collaborator
Modify the copyright date. The file is also maintained upstream, so avoid unnecessary reformatting.
@@ -17,7 +17,6 @@
import triton.language as tl
from torch.utils.cpp_extension import IS_HIP_EXTENSION


@triton.jit
def online_softmax_kernel(
    X_ptr,

@@ -118,7 +117,7 @@ def cross_entropy_kernel(
        m_d_X_y_stride: The stride of m/d/X_y tensor.
        rank (int): The rank of this device in the TP group.
        world_size (int): The size of world involved in this distributed loss calculation.
-        ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
+        ignore_idx (int): Tokens to be ignored for loss and gradient calculation. (default -100)
|
Collaborator
There is no default here.
        n_cols (int): The number of columns in the input tensor.
        n_non_ignore (int): The number of non-ignored elements in the batch.
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.

@@ -231,11 +230,13 @@ def cross_entropy_kernel(
else:
    NUM_WARPS = 32


@triton.jit
def element_mul_kernel(
    X_ptr,
    X_stride,
    grad_output_ptr,
    grad_output_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):

@@ -258,6 +259,7 @@ def element_mul_kernel(
    X_ptr += program_id * X_stride

    # Load the gradient output value
    grad_output_ptr += program_id * grad_output_stride
    grad_output = tl.load(grad_output_ptr)

    # Perform the element-wise multiplication
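For readers following along, here is a self-contained sketch of the row-scaling pattern that element_mul_kernel follows. The kernel name _row_scale_kernel and the test values are mine, not part of this PR: each Triton program handles one row of X and multiplies it in place by the grad_output element selected through the stride argument.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _row_scale_kernel(X_ptr, X_stride, g_ptr, g_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    X_ptr += pid * X_stride
    # g_stride == 0 broadcasts a single scalar gradient to every row;
    # g_stride == 1 picks a per-row gradient (the unreduced-loss case).
    g = tl.load(g_ptr + pid * g_stride)
    for start in range(0, n_cols, BLOCK_SIZE):
        offs = start + tl.arange(0, BLOCK_SIZE)
        mask = offs < n_cols
        x = tl.load(X_ptr + offs, mask=mask)
        tl.store(X_ptr + offs, x * g, mask=mask)


X = torch.randn(4, 1000, device="cuda")
g = torch.arange(1.0, 5.0, device="cuda")
expected = X * g[:, None]
_row_scale_kernel[(X.shape[0],)](X, X.stride(0), g, 1, X.shape[1], BLOCK_SIZE=256)
assert torch.allclose(X, expected)
```

Launching the same kernel with a 0-d grad_output and g_stride=0 would exercise the broadcast path used for a reduced loss.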
@@ -279,6 +281,8 @@ def cross_entropy_forward(
    B, SQ, V = _input.shape
    n_rows = B * SQ
+    valid_token_count = int((target != ignore_idx).sum().item())
+    denom = max(1, valid_token_count)

    assert reduce(mul, list(target.size())) == (B * SQ), "Each token needs a target token ID."
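The two lines added above compute the number of targets that are not ignore_idx, clamped to at least 1; the next hunk uses this as the denominator of the reduced loss instead of B * SQ. A quick, illustrative check (not part of the PR) against PyTorch's own mean reduction, which also divides by the non-ignored count:

```python
import torch
import torch.nn.functional as F

ignore_idx = -100
B, SQ, V = 2, 4, 11
logits = torch.randn(B, SQ, V)
target = torch.randint(0, V, (B, SQ))
target[0, :2] = ignore_idx  # mask out a few tokens

per_token = F.cross_entropy(
    logits.reshape(-1, V), target.reshape(-1),
    ignore_index=ignore_idx, reduction="none",
)
valid_token_count = int((target != ignore_idx).sum().item())
denom = max(1, valid_token_count)  # same clamp as the diff above

mean_loss = F.cross_entropy(
    logits.reshape(-1, V), target.reshape(-1),
    ignore_index=ignore_idx, reduction="mean",
)
# Dividing by the non-ignored count matches reduction="mean";
# dividing by B * SQ would shrink the loss whenever tokens are ignored.
assert torch.allclose(per_token.sum() / denom, mean_loss)
```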
@@ -334,25 +338,29 @@ def cross_entropy_forward(
        world_size=world_size,
        ignore_idx=ignore_idx,
        n_cols=V,
-        n_non_ignore=n_rows,
+        n_non_ignore=denom,
        reduce_loss=reduce_loss,
        label_smoothing=label_smoothing,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=NUM_WARPS,
    )

-    loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / n_rows)
+    loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / denom)

    return loss, _input


-def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor):
+def cross_entropy_backward(
+    _input: torch.Tensor, grad_output: torch.Tensor, is_cg_capturable: bool = False
|
Collaborator
This code conflicts with the upcoming IFU 2.8.
+):
    """Backward implementation of cross entropy loss kernel"""

    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
-    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+    # Only check torch.equal when not in CUDA graph capturable mode
+    if not is_cg_capturable and torch.equal(
+        grad_output, torch.tensor(1.0, device=grad_output.device)
+    ):
        pass

    else:
        B, SQ, V = _input.shape
        n_rows = B * SQ
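The is_cg_capturable flag exists because torch.equal on CUDA tensors returns a Python bool, which forces a device-to-host copy and a stream synchronization, and that is not permitted while a CUDA graph is being captured. The sketch below is a hypothetical caller-side illustration (the backward_like wrapper and the shapes are mine, not this PR's API): during capture the host-side check is skipped and the scale is applied unconditionally, trading one redundant multiply for capturability.

```python
import torch


def backward_like(_input, grad_output, is_cg_capturable=False):
    # Mirrors the guarded branch above; the tensor multiply stands in
    # for the element_mul_kernel launch.
    if not is_cg_capturable and torch.equal(
        grad_output, torch.tensor(1.0, device=grad_output.device)
    ):
        return _input  # grad_output == 1.0, nothing to scale
    return _input * grad_output


static_input = torch.randn(2, 4, 8, device="cuda")
static_grad = torch.ones((), device="cuda")

# Warm up on a side stream before capture, as recommended for CUDA graphs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    backward_like(static_input, static_grad, is_cg_capturable=True)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Capture succeeds because the data-dependent torch.equal branch is skipped.
    out = backward_like(static_input, static_grad, is_cg_capturable=True)
g.replay()
```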
@@ -362,9 +370,10 @@ def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor):
            _input,
            _input.stride(-2),
            grad_output,
+            1 if grad_output.numel() > 1 else 0,
            V,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=NUM_WARPS,
        )

-    return _input
+    return _input
What is the purpose of this file change?
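A closing note on the stride argument added in the last hunk (1 if grad_output.numel() > 1 else 0): a reduced loss hands backward a scalar grad_output, so stride 0 makes every program reload the same value, while an unreduced loss hands back one gradient per token, so stride 1 indexes by program id. A small host-side model of that indexing, illustrative only and with hypothetical names:

```python
import torch


def scale_rows(X, grad_output):
    # Host-side model of element_mul_kernel's indexing: it loads
    # grad_output_ptr + program_id * grad_output_stride for each row.
    stride = 1 if grad_output.numel() > 1 else 0
    flat = grad_output.reshape(-1)
    for row in range(X.shape[0]):
        X[row] *= flat[row * stride]
    return X


X = torch.randn(4, 8)
# Scalar grad_output (reduced loss): stride 0 broadcasts the same value.
assert torch.allclose(scale_rows(X.clone(), torch.tensor(2.0)), X * 2.0)
# Per-row grad_output (unreduced loss): stride 1 picks row i's gradient.
g = torch.arange(1.0, 5.0)
assert torch.allclose(scale_rows(X.clone(), g), X * g[:, None])
```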