Skip to content

Commit 8d2f811

Browse files
committed
Avoid copying output from GPU to CPU
1 parent df626bd commit 8d2f811

File tree

6 files changed

+131
-17
lines changed

6 files changed

+131
-17
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0123293118efb08ac4ffc4fefe9d330201465c93
1+
de4f3c4978b4d36cc0bb8f87c6877a4a040d7ae7

backends/aoti/aoti_delegate_handle.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/runtime/core/error.h>
1212
#include <executorch/runtime/core/evalue.h>
13+
#include <string>
1314

1415
namespace executorch {
1516
namespace backends {
@@ -85,6 +86,7 @@ struct AOTIDelegateHandle {
8586
AOTInductorModelContainerHandle container_handle;
8687
void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
8788
// dependency
89+
std::string method_name;
8890

8991
// Function pointers specific to this handle's shared library
9092
AOTInductorModelContainerCreateWithDeviceFunc create_with_device;

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 95 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
1414
#include <cstdio>
1515

16+
#include <array>
1617
#include <filesystem>
1718
#include <fstream>
19+
#include <mutex>
1820
#include <string>
1921
#include <vector>
2022

@@ -35,20 +37,54 @@ using executorch::runtime::ArrayRef;
3537
using executorch::runtime::Backend;
3638
using executorch::runtime::BackendExecutionContext;
3739
using executorch::runtime::BackendInitContext;
40+
using executorch::runtime::BackendOption;
41+
using executorch::runtime::BackendOptionContext;
3842
using executorch::runtime::CompileSpec;
3943
using executorch::runtime::DelegateHandle;
4044
using executorch::runtime::Error;
4145
using executorch::runtime::EValue;
4246
using executorch::runtime::FreeableBuffer;
47+
using executorch::runtime::kMaxOptionValueLength;
4348
using executorch::runtime::MemoryAllocator;
4449
using executorch::runtime::NamedDataMap;
4550
using executorch::runtime::Result;
4651
using executorch::runtime::Span;
4752
using executorch::runtime::etensor::Tensor;
4853

54+
namespace {
// Key of the backend option that names the single method whose outputs
// should not be copied from GPU back to CPU after execution.
constexpr char kSkipCopyOutputToCpuForMethod[] =
    "skip_copy_output_to_cpu_for_method";
} // namespace
58+
4959
class ET_EXPERIMENTAL CudaBackend final
5060
: public ::executorch::runtime::BackendInterface {
5161
private:
62+
void set_skip_copy_method(
63+
const std::array<char, kMaxOptionValueLength>& raw) {
64+
std::lock_guard<std::mutex> guard(skip_copy_method_mutex_);
65+
skip_copy_method_ = std::string(raw.data());
66+
}
67+
68+
std::array<char, kMaxOptionValueLength> get_skip_copy_method_as_option()
69+
const {
70+
std::array<char, kMaxOptionValueLength> out{};
71+
std::string value;
72+
{
73+
std::lock_guard<std::mutex> guard(skip_copy_method_mutex_);
74+
value = skip_copy_method_;
75+
}
76+
std::snprintf(out.data(), out.size(), "%s", value.c_str());
77+
return out;
78+
}
79+
80+
bool should_skip_copy_for_method(const std::string& method_name) const {
81+
if (method_name.empty()) {
82+
return false;
83+
}
84+
std::lock_guard<std::mutex> guard(skip_copy_method_mutex_);
85+
return method_name == skip_copy_method_;
86+
}
87+
5288
Error load_function_pointers_into_handle(
5389
void* so_handle,
5490
AOTIDelegateHandle* handle) const {
@@ -91,6 +127,38 @@ class ET_EXPERIMENTAL CudaBackend final
91127
return 1;
92128
}
93129

130+
// Applies backend options. Only the skip-copy key is recognized; options with
// any other key are ignored.
//
// Returns Error::InvalidArgument when the skip-copy option does not hold a
// character-array value, Error::Ok otherwise.
Error set_option(
    ET_UNUSED BackendOptionContext& context,
    const executorch::runtime::Span<BackendOption>& backend_options)
    override {
  for (const auto& option : backend_options) {
    if (std::strcmp(option.key, kSkipCopyOutputToCpuForMethod) != 0) {
      continue;
    }
    const auto* method_name =
        std::get_if<std::array<char, kMaxOptionValueLength>>(&option.value);
    if (method_name == nullptr) {
      ET_LOG(
          Error,
          "Option %s must be a method name string.",
          kSkipCopyOutputToCpuForMethod);
      return Error::InvalidArgument;
    }
    set_skip_copy_method(*method_name);
  }
  return Error::Ok;
}
150+
151+
// Fills in the value of every option whose key is the skip-copy key with the
// currently configured method name. Options with other keys are left
// untouched. Always returns Error::Ok.
Error get_option(
    ET_UNUSED BackendOptionContext& context,
    executorch::runtime::Span<BackendOption>& backend_options) override {
  for (auto& option : backend_options) {
    const bool is_skip_copy_key =
        std::strcmp(option.key, kSkipCopyOutputToCpuForMethod) == 0;
    if (!is_skip_copy_key) {
      continue;
    }
    option.value = get_skip_copy_method_as_option();
  }
  return Error::Ok;
}
161+
94162
// Once per loaded binary blob
95163
Result<DelegateHandle*> init(
96164
BackendInitContext& context,
@@ -159,6 +227,7 @@ class ET_EXPERIMENTAL CudaBackend final
159227
AOTIDelegateHandle* handle = new AOTIDelegateHandle();
160228
handle->so_handle = lib_handle;
161229
handle->so_path = so_path.string();
230+
handle->method_name = method_name;
162231

163232
// Load function pointers specific to this handle's shared library
164233
ET_CHECK_OK_OR_RETURN_ERROR(
@@ -224,7 +293,7 @@ class ET_EXPERIMENTAL CudaBackend final
224293

225294
// Process input tensors: ExecuTorch provides CPU tensors, create GPU
226295
// copies
227-
for (int i = 0; i < n_inputs; i++) {
296+
for (size_t i = 0; i < n_inputs; i++) {
228297
// Get tensor dimensions and properties from ExecuTorch CPU tensor
229298
auto cpu_tensor = &(args[i]->toTensor());
230299
auto sizes = cpu_tensor->sizes();
@@ -260,7 +329,7 @@ class ET_EXPERIMENTAL CudaBackend final
260329
}
261330
// Process output tensors: create GPU counterparts for ExecuTorch CPU
262331
// tensors
263-
for (int i = 0; i < n_outputs; i++) {
332+
for (size_t i = 0; i < n_outputs; i++) {
264333
// Get output tensor dimensions from ExecuTorch CPU tensor
265334
auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
266335
auto sizes = cpu_output_tensor->sizes();
@@ -303,18 +372,26 @@ class ET_EXPERIMENTAL CudaBackend final
303372
"AOTInductorModelContainerRun failed with error code %d",
304373
error);
305374

306-
// Copy GPU output results back to CPU output tensors
307-
for (int i = 0; i < n_outputs; i++) {
308-
auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
309-
// For DYNAMIC_BOUND tensors we try to resize
310-
ET_CHECK_OK_OR_RETURN_ERROR(
311-
resize_tensor(*cpu_output_tensor, gpu_outputs[i]->sizes()),
312-
"Error resizing tensor at output index %d",
313-
i);
314-
ET_CHECK_OK_OR_RETURN_ERROR(
315-
aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0),
316-
"Failed to copy GPU output %d back to CPU",
317-
i);
375+
const bool copy_outputs = !should_skip_copy_for_method(handle->method_name);
376+
377+
if (copy_outputs) {
378+
// Copy GPU output results back to CPU output tensors
379+
for (size_t i = 0; i < n_outputs; i++) {
380+
auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
381+
// For DYNAMIC_BOUND tensors we try to resize
382+
ET_CHECK_OK_OR_RETURN_ERROR(
383+
resize_tensor(*cpu_output_tensor, gpu_outputs[i]->sizes()),
384+
"Error resizing tensor at output index %d",
385+
i);
386+
ET_CHECK_OK_OR_RETURN_ERROR(
387+
aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0),
388+
"Failed to copy GPU output %d back to CPU",
389+
i);
390+
}
391+
} else {
392+
for (size_t i = 0; i < n_outputs; i++) {
393+
args[i + n_inputs]->toTensor() = *gpu_outputs[i];
394+
}
318395
}
319396

320397
return Error::Ok;
@@ -365,6 +442,10 @@ class ET_EXPERIMENTAL CudaBackend final
365442
delete handle;
366443
clear_all_tensors();
367444
}
445+
446+
private:
447+
mutable std::mutex skip_copy_method_mutex_;
448+
std::string skip_copy_method_;
368449
};
369450

370451
} // namespace executorch::backends::cuda

extension/asr/runner/CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,22 @@ set_target_properties(
3535
extension_asr_runner PROPERTIES POSITION_INDEPENDENT_CODE ON
3636
)
3737

38+
# If the project is configured to build with CUDA support, try to find a CUDA
39+
# runtime (prefer the CUDAToolkit package). If found, expose a compile-time
40+
# macro so sources can conditionally compile CUDA-aware code.
41+
if(EXECUTORCH_BUILD_CUDA)
  # QUIET: a missing toolkit is reported by the status message below instead
  # of by find_package itself.
  find_package(CUDAToolkit QUIET)
  if(NOT CUDAToolkit_FOUND)
    message(
      STATUS
        "CUDA requested (EXECUTORCH_BUILD_CUDA=ON) but no CUDA runtime found"
    )
  else()
    # PUBLIC so targets that link extension_asr_runner see the macro as well.
    target_compile_definitions(extension_asr_runner PUBLIC CUDA_AVAILABLE)
    message(STATUS "CUDAToolkit found; defining CUDA_AVAILABLE for ASR runner")
  endif()
endif()
53+
3854
install(
3955
TARGETS extension_asr_runner
4056
EXPORT ExecuTorchTargets

extension/asr/runner/runner.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,22 @@ Error AsrRunner::load() {
107107

108108
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kDecoderMethodName));
109109
decoder_method_loaded_ = true;
110-
110+
#ifdef CUDA_AVAILABLE
111+
executorch::runtime::BackendOptions<1> backend_options;
112+
// Skip the copy only for the encoder; the decoder output is still copied from GPU to CPU for sampling.
113+
// TODO: change sampler to use a CUDA kernel to sample and then skip copying
114+
// decoder output as well
115+
ET_CHECK_OK_OR_RETURN_ERROR(backend_options.set_option(
116+
"skip_copy_output_to_cpu_for_method", kEncoderMethodName));
117+
const auto opt_err =
118+
executorch::runtime::set_option("CudaBackend", backend_options.view());
119+
if (opt_err != ::executorch::runtime::Error::Ok) {
120+
ET_LOG(
121+
Warning,
122+
"Failed to set CUDA backend options: %d",
123+
static_cast<int>(opt_err));
124+
}
125+
#endif
111126
ET_CHECK_OK_OR_RETURN_ERROR(load_tokenizer());
112127
auto eos_ids = get_eos_ids(tokenizer_.get(), module_.get());
113128
if (!eos_ids.empty()) {

requirements-examples.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ datasets == 3.6.0 # 4.0.0 deprecates trust_remote_code and load scripts. For now
44
timm == 1.0.7
55
torchsr == 1.0.4
66
torchtune >= 0.6.1
7-
transformers == 4.56.1
7+
transformers == 5.0.0rc1

0 commit comments

Comments
 (0)