Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion backends/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ install(
)

# CUDA backend implementation
set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp)
set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp
runtime/cuda_mutable_state.cpp
)
if(_cuda_is_msvc_toolchain)
# MSVC links aoti_cuda_backend into portable_lib without relying on C++
# symbols exported from aoti_cuda_shims.dll.
Expand Down
2 changes: 2 additions & 0 deletions backends/cuda/runtime/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,11 @@ runtime.cxx_library(
name = "cuda_backend",
srcs = [
"cuda_backend.cpp",
"cuda_mutable_state.cpp",
],
headers = [
"cuda_delegate_handle.h",
"cuda_mutable_state.h",
],
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
link_whole = True,
Expand Down
11 changes: 11 additions & 0 deletions backends/cuda/runtime/cuda_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/cuda/runtime/cuda_allocator.h>
#include <executorch/backends/cuda/runtime/cuda_delegate_handle.h>
#include <executorch/backends/cuda/runtime/cuda_mutable_state.h>
#include <executorch/backends/cuda/runtime/platform/platform.h>
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/utils.h>
Expand Down Expand Up @@ -466,6 +467,10 @@ class ET_EXPERIMENTAL CudaBackend final
kCudaGraphWarmupSteps);
}

// Record whether this AOTI build exposes the constant-management symbols
// needed for per-session mutable-buffer rebinding (CUDA V2 multi-session).
mutable_state_note_handle(handle);

return (DelegateHandle*)handle; // Return the handle post-processing
}

Expand Down Expand Up @@ -514,6 +519,12 @@ class ET_EXPERIMENTAL CudaBackend final
static_cast<int>(device_type));
}

// CUDA V2 multi-session: if a logical session is active on this thread,
// rebind this container's mutable constants (KV/conv/recurrent) to the
// session's own GPU buffers before running. No-op for
// single-session/legacy.
ET_CHECK_OK_OR_RETURN_ERROR(mutable_state_rebind_for_execute(handle));

// ---------------------------------------------------------------
// CUDA graph REPLAY path — skip all tensor setup and just replay
// ---------------------------------------------------------------
Expand Down
Loading
Loading