@@ -205,11 +205,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
205205 auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t > &shape,
206206 size_t offset, at::ScalarType dtype) -> at::Tensor {
207207 std::vector<int64_t > shape_int64 (shape.begin (), shape.end ());
208- // in the case where full buffer is empty because local rank receives no tokens for all the experts
209- // then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob
210- // but in the case where some experts receive tokens, some not, we want to leverage from_blob
211- // as much as possible to avoid CPU overhead
212- if (buffer->data_ptr <uint8_t >() == nullptr ) {
208+ bool is_empty_shape = product (shape) == 0 ;
209+ if (buffer->data_ptr <uint8_t >() == nullptr || is_empty_shape) {
213210 return at::empty (shape_int64, at::device (at::kCUDA ).dtype (dtype));
214211 }
215212 return at::from_blob (
@@ -359,11 +356,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
359356 auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t > &shape,
360357 size_t offset, at::ScalarType dtype) -> at::Tensor {
361358 std::vector<int64_t > shape_int64 (shape.begin (), shape.end ());
362- // in the case where full buffer is empty because local rank receives no tokens for all the experts
363- // then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob
364- // but in the case where some experts receive tokens, some not, we want to leverage from_blob
365- // as much as possible to avoid CPU overhead
366- if (buffer->data_ptr <uint8_t >() == nullptr ) {
359+ bool is_empty_shape = product (shape) == 0 ;
360+ if (buffer->data_ptr <uint8_t >() == nullptr || is_empty_shape) {
367361 return at::empty (shape_int64, at::device (at::kCUDA ).dtype (dtype));
368362 }
369363 return at::from_blob (
0 commit comments