From 88c61bec5468b31d2e46fc8d2a26ebb659ee6ca7 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Fri, 12 Dec 2025 15:33:22 -0800
Subject: [PATCH 1/2] Avoid copying output from GPU to CPU

---
 .../ci_commit_pins/optimum-executorch.txt     |   2 +-
 .ci/scripts/test_huggingface_optimum_model.py |  34 +++++-
 backends/aoti/aoti_delegate_handle.h          |   2 +
 backends/cuda/runtime/cuda_backend.cpp        | 110 +++++++++++++++---
 extension/asr/runner/CMakeLists.txt           |  16 +++
 extension/asr/runner/runner.cpp               |  17 ++-
 requirements-examples.txt                     |   2 +-
 7 files changed, 164 insertions(+), 19 deletions(-)

diff --git a/.ci/docker/ci_commit_pins/optimum-executorch.txt b/.ci/docker/ci_commit_pins/optimum-executorch.txt
index 156ff2f3c82..2aea6eef8d6 100644
--- a/.ci/docker/ci_commit_pins/optimum-executorch.txt
+++ b/.ci/docker/ci_commit_pins/optimum-executorch.txt
@@ -1 +1 @@
-0123293118efb08ac4ffc4fefe9d330201465c93
+de4f3c4978b4d36cc0bb8f87c6877a4a040d7ae7
diff --git a/.ci/scripts/test_huggingface_optimum_model.py b/.ci/scripts/test_huggingface_optimum_model.py
index e5d815cfc00..bebe3d5dd34 100644
--- a/.ci/scripts/test_huggingface_optimum_model.py
+++ b/.ci/scripts/test_huggingface_optimum_model.py
@@ -170,6 +170,35 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
     assert check_causal_lm_output_quality(model_id, generated_tokens) is True
 
 
+def get_tokenizer_path(model_dir: str, saved_files: tuple) -> str:
+    """
+    Determine the tokenizer path based on files saved by tokenizer.save_pretrained().
+
+    Args:
+        model_dir: The directory where tokenizer files were saved
+        saved_files: Tuple of file paths returned by tokenizer.save_pretrained()
+
+    Returns:
+        The path to use for loading the tokenizer (either a specific file or directory)
+
+    Raises:
+        ValueError: If no supported tokenizer file format is found
+    """
+    saved_filenames = {Path(f).name for f in saved_files}
+
+    if "tokenizer.model" in saved_filenames:
+        return f"{model_dir}/tokenizer.model"
+
+    if "tokenizer.json" in saved_filenames:
+        return model_dir
+
+    # No supported tokenizer format found
+    raise ValueError(
+        f"Unsupported tokenizer format. Expected 'tokenizer.model' (SentencePiece) "
+        f"or 'tokenizer.json' (HuggingFace) but found: {saved_filenames}"
+    )
+
+
 def test_llm_with_image_modality(
     model_id, model_dir, recipe, *, quantize=True, run_only=False
 ):
@@ -196,7 +225,8 @@ def test_llm_with_image_modality(
     cli_export(command, model_dir)
 
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.save_pretrained(model_dir)
+    saved_files = tokenizer.save_pretrained(model_dir)
+    tokenizer_path = get_tokenizer_path(model_dir, saved_files)
 
     # input
     processor = AutoProcessor.from_pretrained(model_id)
@@ -232,7 +262,7 @@ def test_llm_with_image_modality(
 
     from executorch.extension.llm.runner import GenerationConfig, MultimodalRunner
 
-    runner = MultimodalRunner(f"{model_dir}/model.pte", f"{model_dir}/tokenizer.model")
+    runner = MultimodalRunner(f"{model_dir}/model.pte", tokenizer_path)
     generated_text = runner.generate_text_hf(
         inputs,
         GenerationConfig(max_new_tokens=128, temperature=0, echo=False),
diff --git a/backends/aoti/aoti_delegate_handle.h b/backends/aoti/aoti_delegate_handle.h
index 82ce2521750..b14e02da9ef 100644
--- a/backends/aoti/aoti_delegate_handle.h
+++ b/backends/aoti/aoti_delegate_handle.h
@@ -10,6 +10,7 @@
 
 #include
 #include
+#include <string>
 
 namespace executorch {
 namespace backends {
@@ -85,6 +86,7 @@ struct AOTIDelegateHandle {
   AOTInductorModelContainerHandle container_handle;
   void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
                      // dependency
+  std::string method_name;
 
   // Function pointers specific to this handle's shared library
   AOTInductorModelContainerCreateWithDeviceFunc create_with_device;
diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp
index 0cef859ddfb..cd1c6b96f02 100644
--- a/backends/cuda/runtime/cuda_backend.cpp
+++ b/backends/cuda/runtime/cuda_backend.cpp
@@ -8,13 +8,16 @@
 
 #include
 #include
+#include
 #include
 #include
 #include
 #include
+#include
 #include
 #include
+#include
 #include
 #include
 
@@ -35,20 +38,54 @@
 using executorch::runtime::ArrayRef;
 using executorch::runtime::Backend;
 using executorch::runtime::BackendExecutionContext;
 using executorch::runtime::BackendInitContext;
+using executorch::runtime::BackendOption;
+using executorch::runtime::BackendOptionContext;
 using executorch::runtime::CompileSpec;
 using executorch::runtime::DelegateHandle;
 using executorch::runtime::Error;
 using executorch::runtime::EValue;
 using executorch::runtime::FreeableBuffer;
+using executorch::runtime::kMaxOptionValueLength;
 using executorch::runtime::MemoryAllocator;
 using executorch::runtime::NamedDataMap;
 using executorch::runtime::Result;
 using executorch::runtime::Span;
 using executorch::runtime::etensor::Tensor;
 
+namespace {
+constexpr char kSkipCopyOutputToCpuForMethod[] =
+    "skip_copy_output_to_cpu_for_method";
+}
+
 class ET_EXPERIMENTAL CudaBackend final
     : public ::executorch::runtime::BackendInterface {
  private:
+  void set_skip_copy_method(
+      const std::array<char, kMaxOptionValueLength>& raw) {
+    std::lock_guard guard(skip_copy_method_mutex_);
+    skip_copy_method_ = std::string(raw.data());
+  }
+
+  std::array<char, kMaxOptionValueLength> get_skip_copy_method_as_option()
+      const {
+    std::array<char, kMaxOptionValueLength> out{};
+    std::string value;
+    {
+      std::lock_guard guard(skip_copy_method_mutex_);
+      value = skip_copy_method_;
+    }
+    std::snprintf(out.data(), out.size(), "%s", value.c_str());
+    return out;
+  }
+
+  bool should_skip_copy_for_method(const std::string& method_name) const {
+    if (method_name.empty()) {
+      return false;
+    }
+    std::lock_guard guard(skip_copy_method_mutex_);
+    return method_name == skip_copy_method_;
+  }
+
   Error load_function_pointers_into_handle(
       void* so_handle,
       AOTIDelegateHandle* handle) const {
@@ -91,6 +128,38 @@ class ET_EXPERIMENTAL CudaBackend final
     return 1;
   }
 
+  Error set_option(
+      ET_UNUSED BackendOptionContext& context,
+      const executorch::runtime::Span<BackendOption>& backend_options)
+      override {
+    for (const auto& option : backend_options) {
+      if (std::strcmp(option.key, kSkipCopyOutputToCpuForMethod) == 0) {
+        if (auto* val = std::get_if<std::array<char, kMaxOptionValueLength>>(
+                &option.value)) {
+          set_skip_copy_method(*val);
+        } else {
+          ET_LOG(
+              Error,
+              "Option %s must be a method name string.",
+              kSkipCopyOutputToCpuForMethod);
+          return Error::InvalidArgument;
+        }
+      }
+    }
+    return Error::Ok;
+  }
+
+  Error get_option(
+      ET_UNUSED BackendOptionContext& context,
+      executorch::runtime::Span<BackendOption>& backend_options) override {
+    for (auto& option : backend_options) {
+      if (std::strcmp(option.key, kSkipCopyOutputToCpuForMethod) == 0) {
+        option.value = get_skip_copy_method_as_option();
+      }
+    }
+    return Error::Ok;
+  }
+
   // Once per loaded binary blob
   Result<DelegateHandle*> init(
       BackendInitContext& context,
@@ -159,6 +228,7 @@ class ET_EXPERIMENTAL CudaBackend final
     AOTIDelegateHandle* handle = new AOTIDelegateHandle();
     handle->so_handle = lib_handle;
     handle->so_path = so_path.string();
+    handle->method_name = method_name;
 
     // Load function pointers specific to this handle's shared library
     ET_CHECK_OK_OR_RETURN_ERROR(
@@ -224,7 +294,7 @@ class ET_EXPERIMENTAL CudaBackend final
 
     // Process input tensors: ExecuTorch provides CPU tensors, create GPU
     // copies
-    for (int i = 0; i < n_inputs; i++) {
+    for (size_t i = 0; i < n_inputs; i++) {
       // Get tensor dimensions and properties from ExecuTorch CPU tensor
       auto cpu_tensor = &(args[i]->toTensor());
       auto sizes = cpu_tensor->sizes();
@@ -260,7 +330,7 @@ class ET_EXPERIMENTAL CudaBackend final
     }
     // Process output tensors: create GPU counterparts for ExecuTorch CPU
     // tensors
-    for (int i = 0; i < n_outputs; i++) {
+    for (size_t i = 0; i < n_outputs; i++) {
       // Get output tensor dimensions from ExecuTorch CPU tensor
       auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
       auto sizes = cpu_output_tensor->sizes();
@@ -303,18 +373,26 @@ class ET_EXPERIMENTAL CudaBackend final
         "AOTInductorModelContainerRun failed with error code %d",
         error);
 
-    // Copy GPU output results back to CPU output tensors
-    for (int i = 0; i < n_outputs; i++) {
-      auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
-      // For DYNAMIC_BOUND tensors we try to resize
-      ET_CHECK_OK_OR_RETURN_ERROR(
-          resize_tensor(*cpu_output_tensor, gpu_outputs[i]->sizes()),
-          "Error resizing tensor at output index %d",
-          i);
-      ET_CHECK_OK_OR_RETURN_ERROR(
-          aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0),
-          "Failed to copy GPU output %d back to CPU",
-          i);
+    const bool copy_outputs =
+        !should_skip_copy_for_method(handle->method_name);
+
+    if (copy_outputs) {
+      // Copy GPU output results back to CPU output tensors
+      for (size_t i = 0; i < n_outputs; i++) {
+        auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
+        // For DYNAMIC_BOUND tensors we try to resize
+        ET_CHECK_OK_OR_RETURN_ERROR(
+            resize_tensor(*cpu_output_tensor, gpu_outputs[i]->sizes()),
+            "Error resizing tensor at output index %d",
+            i);
+        ET_CHECK_OK_OR_RETURN_ERROR(
+            aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0),
+            "Failed to copy GPU output %d back to CPU",
+            i);
+      }
+    } else {
+      // Skip the device-to-host copy: alias the output EValues to the GPU
+      // tensors so the next method can consume them in place.
+      for (size_t i = 0; i < n_outputs; i++) {
+        args[i + n_inputs]->toTensor() = *gpu_outputs[i];
+      }
     }
 
     return Error::Ok;
@@ -365,6 +443,10 @@ class ET_EXPERIMENTAL CudaBackend final
     delete handle;
     clear_all_tensors();
   }
+
+ private:
+  mutable std::mutex skip_copy_method_mutex_;
+  std::string skip_copy_method_;
 };
 
 } // namespace executorch::backends::cuda
diff --git a/extension/asr/runner/CMakeLists.txt b/extension/asr/runner/CMakeLists.txt
index cc9ba01596a..c3d77712017 100644
--- a/extension/asr/runner/CMakeLists.txt
+++ b/extension/asr/runner/CMakeLists.txt
@@ -35,6 +35,22 @@ set_target_properties(
   extension_asr_runner PROPERTIES POSITION_INDEPENDENT_CODE ON
 )
 
+# If the project is configured to build with CUDA support, try to find a CUDA
+# runtime (prefer the CUDAToolkit package). If found, expose a compile-time
+# macro so sources can conditionally compile CUDA-aware code.
+if(EXECUTORCH_BUILD_CUDA)
+  find_package(CUDAToolkit QUIET)
+  if(CUDAToolkit_FOUND)
+    target_compile_definitions(extension_asr_runner PUBLIC CUDA_AVAILABLE)
+    message(STATUS "CUDAToolkit found; defining CUDA_AVAILABLE for ASR runner")
+  else()
+    message(
+      STATUS
+        "CUDA requested (EXECUTORCH_BUILD_CUDA=ON) but no CUDA runtime found"
+    )
+  endif()
+endif()
+
 install(
   TARGETS extension_asr_runner
   EXPORT ExecuTorchTargets
diff --git a/extension/asr/runner/runner.cpp b/extension/asr/runner/runner.cpp
index 4f2523989c1..61eb7e0366f 100644
--- a/extension/asr/runner/runner.cpp
+++ b/extension/asr/runner/runner.cpp
@@ -107,7 +107,22 @@ Error AsrRunner::load() {
 
   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kDecoderMethodName));
   decoder_method_loaded_ = true;
-
+#ifdef CUDA_AVAILABLE
+  executorch::runtime::BackendOptions<1> backend_options;
+  // For the decoder, still copy output from GPU to CPU for sampling.
+  // TODO: change the sampler to use a CUDA kernel, then skip copying the
+  // decoder output as well.
+  ET_CHECK_OK_OR_RETURN_ERROR(backend_options.set_option(
+      "skip_copy_output_to_cpu_for_method", kEncoderMethodName));
+  const auto opt_err =
+      executorch::runtime::set_option("CudaBackend", backend_options.view());
+  if (opt_err != ::executorch::runtime::Error::Ok) {
+    ET_LOG(
+        Warning,
+        "Failed to set CUDA backend options: %d",
+        static_cast<int>(opt_err));
+  }
+#endif
   ET_CHECK_OK_OR_RETURN_ERROR(load_tokenizer());
   auto eos_ids = get_eos_ids(tokenizer_.get(), module_.get());
   if (!eos_ids.empty()) {
diff --git a/requirements-examples.txt b/requirements-examples.txt
index 368159f96e9..415e4101312 100644
--- a/requirements-examples.txt
+++ b/requirements-examples.txt
@@ -4,4 +4,4 @@ datasets == 3.6.0 # 4.0.0 deprecates trust_remote_code and load scripts. For now
 timm == 1.0.7
 torchsr == 1.0.4
 torchtune >= 0.6.1
-transformers == 4.56.1
+transformers == 5.0.0rc1

From 2271984a751468f050d56883b3de01386753ccb6 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Wed, 17 Dec 2025 01:31:04 -0800
Subject: [PATCH 2/2] Add backend options header to runner.cpp

---
 extension/asr/runner/runner.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/extension/asr/runner/runner.cpp b/extension/asr/runner/runner.cpp
index 61eb7e0366f..85af0ca2d39 100644
--- a/extension/asr/runner/runner.cpp
+++ b/extension/asr/runner/runner.cpp
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
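
Note for reviewers (not part of the patches): the CUDA backend above consumes the new "skip_copy_output_to_cpu_for_method" option through the generic backend-options API, exactly as the ASR runner does in extension/asr/runner/runner.cpp. A minimal sketch of how another client could set the same option follows; it uses only calls that appear in this series, but the include path and the helper name are illustrative assumptions, not part of the change.

  // Sketch: keep one method's outputs on the GPU by opting it out of the
  // device-to-host copy in the CUDA backend. Assumes BackendOptions and
  // set_option are declared in the backend options header added by PATCH 2/2;
  // adjust the include to match your checkout.
  #include <executorch/runtime/backend/options.h> // assumed location

  using executorch::runtime::BackendOptions;
  using executorch::runtime::Error;

  // Hypothetical helper; method_name would be e.g. the encoder method name.
  // Assumes set_option accepts a C-string value, as it appears to in
  // runner.cpp above.
  Error keep_outputs_on_gpu(const char* method_name) {
    BackendOptions<1> opts; // room for exactly one option
    // Option key defined in cuda_backend.cpp (kSkipCopyOutputToCpuForMethod).
    Error err =
        opts.set_option("skip_copy_output_to_cpu_for_method", method_name);
    if (err != Error::Ok) {
      return err;
    }
    // "CudaBackend" is the backend name the ASR runner targets in this series.
    return executorch::runtime::set_option("CudaBackend", opts.view());
  }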