Skip to content

Commit c449c6c

Browse files
zhongbozhu and ksivaman authored
[PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (NVIDIA#2134)
use torch empty for empty shape instead of from_blob

Signed-off-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent 06a38cc commit c449c6c

1 file changed

Lines changed: 4 additions & 10 deletions

File tree

  • transformer_engine/pytorch/csrc/extensions

transformer_engine/pytorch/csrc/extensions/cast.cpp

Lines changed: 4 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -205,11 +205,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
205205
auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
206206
size_t offset, at::ScalarType dtype) -> at::Tensor {
207207
std::vector<int64_t> shape_int64(shape.begin(), shape.end());
208-
// in the case where full buffer is empty because local rank receives no tokens for all the experts
209-
// then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob
210-
// but in the case where some experts receive tokens, some not, we want to leverage from_blob
211-
// as much as possible to avoid CPU overhead
212-
if (buffer->data_ptr<uint8_t>() == nullptr) {
208+
bool is_empty_shape = product(shape) == 0;
209+
if (buffer->data_ptr<uint8_t>() == nullptr || is_empty_shape) {
213210
return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
214211
}
215212
return at::from_blob(
@@ -359,11 +356,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
359356
auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
360357
size_t offset, at::ScalarType dtype) -> at::Tensor {
361358
std::vector<int64_t> shape_int64(shape.begin(), shape.end());
362-
// in the case where full buffer is empty because local rank receives no tokens for all the experts
363-
// then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob
364-
// but in the case where some experts receive tokens, some not, we want to leverage from_blob
365-
// as much as possible to avoid CPU overhead
366-
if (buffer->data_ptr<uint8_t>() == nullptr) {
359+
bool is_empty_shape = product(shape) == 0;
360+
if (buffer->data_ptr<uint8_t>() == nullptr || is_empty_shape) {
367361
return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
368362
}
369363
return at::from_blob(

0 commit comments

Comments
 (0)