From 0ccce90802e9f849f23e3c6be6dce37ae5b9c04b Mon Sep 17 00:00:00 2001 From: "Micah Chambers (minerva)" Date: Fri, 21 Nov 2025 13:12:43 -0800 Subject: [PATCH 1/9] feat: enqueue pre-bound tensorrs - add getter for tensor type --- src/ffi/sync/engine.rs | 93 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/src/ffi/sync/engine.rs b/src/ffi/sync/engine.rs index dac5461..f1ee3ee 100644 --- a/src/ffi/sync/engine.rs +++ b/src/ffi/sync/engine.rs @@ -32,6 +32,39 @@ unsafe impl Send for Engine {} /// The TensorRT API is thread-safe with regards to all operations on [`Engine`]. unsafe impl Sync for Engine {} +pub enum DataType { + /// 32-bit floating point format. + Float, + /// IEEE 16-bit floating-point format – has a 5 bit exponent and 11 bit significand. + Half, + /// Signed 8-bit integer representing a quantized floating-point value. + Int8, + /// Signed 32-bit integer format. + Int32, + /// 8-bit boolean. 0 = false, 1 = true, other values undefined. + Bool, + /// Unsigned 8-bit integer format. Cannot be used to represent quantized floating-point values. + /// Use the IdentityLayer to convert kUINT8 network-level inputs to {kFLOAT, kHALF} prior to + /// use with other TensorRT layers, or to convert intermediate output before kUINT8 + /// network-level outputs from {kFLOAT, kHALF} to kUINT8. kUINT8 conversions are only supported + /// for {kFLOAT, kHALF}. kUINT8 to {kFLOAT, kHALF} conversion will convert the integer values + /// to equivalent floating point values. {kFLOAT, kHALF} to kUINT8 conversion will convert the + /// floating point values to integer values by truncating towards zero. This conversion has + /// undefined behavior for floating point values outside the range [0.0F, 256.0F) after + /// truncation. kUINT8 conversions are not supported for {kINT8, kINT32, kBOOL}. + Uint8, + /// Signed 8-bit floating point with 1 sign bit, 4 exponent bits, 3 mantissa bits, and exponent-bias 7. + Fp8, + /// Brain float – has an 8 bit exponent and 8 bit significand. + Bf16, + ///Signed 64-bit integer type. + Int64, + /// Signed 4-bit integer type. + Int4, + /// 4-bit floating point type 1 bit sign, 2 bit exponent, 1 bit mantissa + Fp4, +} + impl Engine { #[inline] pub(crate) fn wrap(internal: *mut std::ffi::c_void, runtime: Runtime) -> Self { @@ -58,6 +91,40 @@ impl Engine { num_io_tensors as usize } + pub fn io_tensor_type(&self, io_tensor_index: usize) -> DataType { + let internal = self.as_ptr(); + let io_tensor_index = io_tensor_index as std::os::raw::c_int; + let data_type: i32 = cpp!(unsafe [ + internal as "const void*", + io_tensor_index as "int" + ] -> i32 as "DataType" { + #if NV_TENSORRT_MAJOR >= 10 + const char* name = ((const ICudaEngine*) internal)->getIOTensorName(io_tensor_index); + if(name == nullptr) { + return DataType::Float; + } + return ((const ICudaEngine*) internal)->getTensorDataType(tensor_name_ptr); + #else + return ((const ICudaEngine*) internal)->getBindingDataType(io_tensor_index); + #endif + }); + + match data_type { + 0 => DataType::Float, + 1 => DataType::Half, + 2 => DataType::Int8, + 3 => DataType::Int32, + 4 => DataType:: Bool, + 5 => DataType::Uint8, + 6 => DataType::Fp8, + 7 => DataType::Bf16, + 8 => DataType::Int64, + 9 => DataType::Int4, + 10 => DataType::Fp4, + _ => panic!("Unknown data type ({data_type}), you might be using an unsupported version of TensorRT") + } + } + pub fn io_tensor_name(&self, io_tensor_index: usize) -> String { let internal = self.as_ptr(); let io_tensor_index = io_tensor_index as std::os::raw::c_int; @@ -226,6 +293,32 @@ impl<'engine> ExecutionContext<'engine> { ) } + pub fn bind( + &mut self, + tensor_name: &str, + buffer: &mut async_cuda::ffi::memory::DeviceBuffer, + ) -> Result<()> { + Ok(unsafe { self.set_tensor_address(tensor_name, buffer) }?) + } + + /// Enqueue with pre-bound + /// this allows for assorted types of inputs + pub fn enqueue_prebound(&mut self, stream: &async_cuda::ffi::stream::Stream) -> Result<()> { + let internal = self.as_mut_ptr(); + let stream_ptr = stream.as_internal().as_ptr(); + let success = cpp!(unsafe [ + internal as "void*", + stream_ptr as "const void*" + ] -> bool as "bool" { + return ((IExecutionContext*) internal)->enqueueV3((cudaStream_t) stream_ptr); + }); + if success { + Ok(()) + } else { + Err(last_error()) + } + } + pub fn enqueue( &mut self, io_tensors: &mut std::collections::HashMap< From 8f4d4d570c3a8d266deec5810761344f105d07b0 Mon Sep 17 00:00:00 2001 From: "Micah Chambers (eos)" Date: Fri, 21 Nov 2025 18:55:00 -0800 Subject: [PATCH 2/9] clone, copy, debug --- src/ffi/sync/engine.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ffi/sync/engine.rs b/src/ffi/sync/engine.rs index f1ee3ee..ade7018 100644 --- a/src/ffi/sync/engine.rs +++ b/src/ffi/sync/engine.rs @@ -32,6 +32,7 @@ unsafe impl Send for Engine {} /// The TensorRT API is thread-safe with regards to all operations on [`Engine`]. unsafe impl Sync for Engine {} +#[derive(Copy, Clone, Debug)] pub enum DataType { /// 32-bit floating point format. Float, From 06d4e52e85cf9ce47dcf28ee2c563597f2ab1fad Mon Sep 17 00:00:00 2001 From: "Micah Chambers (eos)" Date: Fri, 21 Nov 2025 23:04:22 -0800 Subject: [PATCH 3/9] bind --- Cargo.lock | 3 +-- Cargo.toml | 4 ++++ src/ffi/sync/engine.rs | 12 ++++++------ 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6b0ebaa..6e6db66 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -29,8 +29,7 @@ dependencies = [ [[package]] name = "async-cuda" version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56bf487caab780f706b84b5714aa01c27996429d0d0e1538617582038dd0526c" +source = "git+https://github.com/micahcc/async-cuda.git?branch=main#80bf68771b6da5a2e1c97cf6045668348dc9be76" dependencies = [ "cpp", "cpp_build", diff --git a/Cargo.toml b/Cargo.toml index bb442b3..2ca3a58 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,10 @@ async-cuda = "0.6.1" cpp = "0.5" tracing = "0.1" +[patch.crates-io] +# need to upstream cuda get devices +async-cuda = { git = "https://github.com/micahcc/async-cuda.git", branch = "main"} + [dev-dependencies] tempfile = "3.4" tokio = { version = "1", default-features = false, features = [ diff --git a/src/ffi/sync/engine.rs b/src/ffi/sync/engine.rs index ade7018..f190cd4 100644 --- a/src/ffi/sync/engine.rs +++ b/src/ffi/sync/engine.rs @@ -2,6 +2,7 @@ use cpp::cpp; use async_cuda::device::DeviceId; use async_cuda::ffi::device::Device; +use async_cuda::ffi::ptr::DevicePtr; use crate::error::last_error; use crate::ffi::memory::HostBuffer; @@ -294,12 +295,12 @@ impl<'engine> ExecutionContext<'engine> { ) } - pub fn bind( + pub fn bind_tensor( &mut self, tensor_name: &str, - buffer: &mut async_cuda::ffi::memory::DeviceBuffer, + buffer: &mut async_cuda::ffi::memory::DeviceTensor, ) -> Result<()> { - Ok(unsafe { self.set_tensor_address(tensor_name, buffer) }?) + Ok(unsafe { self.set_tensor_address::(tensor_name, buffer.as_mut_internal()) }?) } /// Enqueue with pre-bound @@ -331,7 +332,7 @@ impl<'engine> ExecutionContext<'engine> { let internal = self.as_mut_ptr(); for (tensor_name, buffer) in io_tensors { unsafe { - self.set_tensor_address(tensor_name, buffer)?; + self.set_tensor_address::(tensor_name, buffer.as_mut_internal())?; } } let stream_ptr = stream.as_internal().as_ptr(); @@ -379,12 +380,11 @@ impl<'engine> ExecutionContext<'engine> { unsafe fn set_tensor_address( &mut self, tensor_name: &str, - buffer: &mut async_cuda::ffi::memory::DeviceBuffer, + buffer_ptr: &mut DevicePtr, ) -> Result<()> { let internal = self.as_mut_ptr(); let tensor_name_cstr = std::ffi::CString::new(tensor_name).unwrap(); let tensor_name_ptr = tensor_name_cstr.as_ptr(); - let buffer_ptr = buffer.as_mut_internal().as_mut_ptr(); let success = cpp!(unsafe [ internal as "const void*", tensor_name_ptr as "const char*", From 773fe4ce21f0fff37ea2715b456f0878c360df92 Mon Sep 17 00:00:00 2001 From: "Micah Chambers (eos)" Date: Sat, 22 Nov 2025 08:16:22 -0800 Subject: [PATCH 4/9] update for trt8 --- src/ffi/sync/engine.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/ffi/sync/engine.rs b/src/ffi/sync/engine.rs index f190cd4..985ab2b 100644 --- a/src/ffi/sync/engine.rs +++ b/src/ffi/sync/engine.rs @@ -100,13 +100,15 @@ impl Engine { internal as "const void*", io_tensor_index as "int" ] -> i32 as "DataType" { - #if NV_TENSORRT_MAJOR >= 10 + // Added in TRT 8 + #if NV_TENSORRT_MAJOR >= 8 const char* name = ((const ICudaEngine*) internal)->getIOTensorName(io_tensor_index); if(name == nullptr) { - return DataType::Float; + return DataType::kFLOAT; } - return ((const ICudaEngine*) internal)->getTensorDataType(tensor_name_ptr); + return ((const ICudaEngine*) internal)->getTensorDataType(name); #else + // Removed in TRT 10 return ((const ICudaEngine*) internal)->getBindingDataType(io_tensor_index); #endif }); From 305947cdbba189a0b3e6ae6cbe71f4afeff5240a Mon Sep 17 00:00:00 2001 From: "Micah Chambers (eos)" Date: Sat, 22 Nov 2025 16:02:08 -0800 Subject: [PATCH 5/9] try --- src/ffi/pre/includes.rs | 1 + src/ffi/sync/engine.rs | 58 ++++++++++++++++++++++++++++++----------- 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/src/ffi/pre/includes.rs b/src/ffi/pre/includes.rs index a62b609..3b00063 100644 --- a/src/ffi/pre/includes.rs +++ b/src/ffi/pre/includes.rs @@ -3,6 +3,7 @@ use cpp::cpp; cpp! {{ #include #include + #include }} cpp! {{ diff --git a/src/ffi/sync/engine.rs b/src/ffi/sync/engine.rs index 985ab2b..56678cf 100644 --- a/src/ffi/sync/engine.rs +++ b/src/ffi/sync/engine.rs @@ -1,13 +1,12 @@ -use cpp::cpp; - -use async_cuda::device::DeviceId; -use async_cuda::ffi::device::Device; -use async_cuda::ffi::ptr::DevicePtr; - use crate::error::last_error; use crate::ffi::memory::HostBuffer; use crate::ffi::result; use crate::ffi::sync::runtime::Runtime; +use async_cuda::device::DeviceId; +use async_cuda::ffi::device::Device; +use async_cuda::ffi::ptr::DevicePtr; +use cpp::cpp; +use std::sync::Arc; type Result = std::result::Result; @@ -70,6 +69,7 @@ pub enum DataType { impl Engine { #[inline] pub(crate) fn wrap(internal: *mut std::ffi::c_void, runtime: Runtime) -> Self { + eprintln!("Engine address: {internal:?}"); Engine { internal, runtime } } @@ -211,6 +211,7 @@ impl Engine { impl Drop for Engine { fn drop(&mut self) { + eprintln!("Dropping Engine"); Device::set_or_panic(self.runtime.device()); let Engine { internal, .. } = *self; cpp!(unsafe [ @@ -246,8 +247,8 @@ unsafe impl<'engine> Send for ExecutionContext<'engine> {} unsafe impl<'engine> Sync for ExecutionContext<'engine> {} impl ExecutionContext<'static> { - pub fn from_engine(mut engine: Engine) -> Result { - let internal = unsafe { Self::new_internal(&mut engine) }; + pub fn from_engine(engine: Engine) -> Result { + let internal = unsafe { Self::new_internal(&engine) }; result!( internal, Self { @@ -259,10 +260,23 @@ impl ExecutionContext<'static> { ) } - pub fn from_engine_many(mut engine: Engine, num: usize) -> Result> { + pub fn from_shared_engine(engine: Arc) -> Result { + let internal = unsafe { Self::new_internal(&engine) }; + result!( + internal, + Self { + internal, + device: engine.device(), + _parent: Some(engine), + _phantom: Default::default(), + } + ) + } + + pub fn from_engine_many(engine: Engine, num: usize) -> Result> { let mut internals = Vec::with_capacity(num); for _ in 0..num { - internals.push(unsafe { Self::new_internal(&mut engine) }); + internals.push(unsafe { Self::new_internal(&engine) }); } let device = engine.device(); let parent = std::sync::Arc::new(engine); @@ -368,14 +382,17 @@ impl<'engine> ExecutionContext<'engine> { self.device } - unsafe fn new_internal(engine: &mut Engine) -> *mut std::ffi::c_void { + unsafe fn new_internal(engine: &Engine) -> *mut std::ffi::c_void { Device::set_or_panic(engine.device()); - let internal_engine = engine.as_mut_ptr(); + let internal_engine = engine.as_ptr(); let internal = cpp!(unsafe [ internal_engine as "void*" ] -> *mut std::ffi::c_void as "void*" { - return (void*) ((ICudaEngine*) internal_engine)->createExecutionContext(); + void* out = (void*) ((ICudaEngine*) internal_engine)->createExecutionContext(); + fprintf(stderr, "Execution Ptr: %p\n", out); + return out; }); + eprintln!("ExecutionContext address: {internal:?}"); internal } @@ -387,14 +404,24 @@ impl<'engine> ExecutionContext<'engine> { let internal = self.as_mut_ptr(); let tensor_name_cstr = std::ffi::CString::new(tensor_name).unwrap(); let tensor_name_ptr = tensor_name_cstr.as_ptr(); + let buffer_ptr = buffer_ptr.as_ptr(); + eprintln!("buffer: {buffer_ptr:?}"); let success = cpp!(unsafe [ - internal as "const void*", + internal as "void*", tensor_name_ptr as "const char*", buffer_ptr as "void*" ] -> bool as "bool" { + fprintf(stderr, "Engine: %p\n", internal); + fprintf(stderr, "Getting tensor name: %s\n", tensor_name_ptr); + const void* prevTensorAddr = ((IExecutionContext*) internal)->getTensorAddress( + tensor_name_ptr); + fprintf(stderr, "Prev addr: %p", prevTensorAddr); + + fprintf(stderr, "Setting tensor name: %s, buffer ptr: %p, execution: %p\n", tensor_name_ptr, buffer_ptr, internal); return ((IExecutionContext*) internal)->setTensorAddress( tensor_name_ptr, - buffer_ptr + 0 + //buffer_ptr ); }); if success { @@ -407,6 +434,7 @@ impl<'engine> ExecutionContext<'engine> { impl<'engine> Drop for ExecutionContext<'engine> { fn drop(&mut self) { + eprintln!("Dropping ExecutionContext"); Device::set_or_panic(self.device); let ExecutionContext { internal, .. } = *self; cpp!(unsafe [ From 55ffd39669e767aac68591410a7c4a7599e7d8ea Mon Sep 17 00:00:00 2001 From: "Micah Chambers (eos)" Date: Sat, 22 Nov 2025 16:06:10 -0800 Subject: [PATCH 6/9] better --- src/ffi/sync/engine.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ffi/sync/engine.rs b/src/ffi/sync/engine.rs index 56678cf..4330901 100644 --- a/src/ffi/sync/engine.rs +++ b/src/ffi/sync/engine.rs @@ -298,7 +298,7 @@ impl ExecutionContext<'static> { } impl<'engine> ExecutionContext<'engine> { - pub fn new(engine: &'engine mut Engine) -> Result { + pub fn new(engine: &'engine Engine) -> Result { let internal = unsafe { Self::new_internal(engine) }; result!( internal, From 25e3240b69e4e47f708afc8f44836c349711c3d9 Mon Sep 17 00:00:00 2001 From: "Micah Chambers (eos)" Date: Sat, 22 Nov 2025 16:07:48 -0800 Subject: [PATCH 7/9] cleanup --- src/ffi/pre/includes.rs | 1 - src/ffi/sync/engine.rs | 3 --- 2 files changed, 4 deletions(-) diff --git a/src/ffi/pre/includes.rs b/src/ffi/pre/includes.rs index 3b00063..a62b609 100644 --- a/src/ffi/pre/includes.rs +++ b/src/ffi/pre/includes.rs @@ -3,7 +3,6 @@ use cpp::cpp; cpp! {{ #include #include - #include }} cpp! {{ diff --git a/src/ffi/sync/engine.rs b/src/ffi/sync/engine.rs index 4330901..35dca38 100644 --- a/src/ffi/sync/engine.rs +++ b/src/ffi/sync/engine.rs @@ -69,7 +69,6 @@ pub enum DataType { impl Engine { #[inline] pub(crate) fn wrap(internal: *mut std::ffi::c_void, runtime: Runtime) -> Self { - eprintln!("Engine address: {internal:?}"); Engine { internal, runtime } } @@ -211,7 +210,6 @@ impl Engine { impl Drop for Engine { fn drop(&mut self) { - eprintln!("Dropping Engine"); Device::set_or_panic(self.runtime.device()); let Engine { internal, .. } = *self; cpp!(unsafe [ @@ -434,7 +432,6 @@ impl<'engine> ExecutionContext<'engine> { impl<'engine> Drop for ExecutionContext<'engine> { fn drop(&mut self) { - eprintln!("Dropping ExecutionContext"); Device::set_or_panic(self.device); let ExecutionContext { internal, .. } = *self; cpp!(unsafe [ From f0342601d1a689d22892a7c8178ec20ff7cbd26b Mon Sep 17 00:00:00 2001 From: "Micah Chambers (eos)" Date: Sat, 22 Nov 2025 16:08:39 -0800 Subject: [PATCH 8/9] cleanup --- src/ffi/sync/engine.rs | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/ffi/sync/engine.rs b/src/ffi/sync/engine.rs index 35dca38..3be4bcf 100644 --- a/src/ffi/sync/engine.rs +++ b/src/ffi/sync/engine.rs @@ -387,10 +387,8 @@ impl<'engine> ExecutionContext<'engine> { internal_engine as "void*" ] -> *mut std::ffi::c_void as "void*" { void* out = (void*) ((ICudaEngine*) internal_engine)->createExecutionContext(); - fprintf(stderr, "Execution Ptr: %p\n", out); return out; }); - eprintln!("ExecutionContext address: {internal:?}"); internal } @@ -403,19 +401,13 @@ impl<'engine> ExecutionContext<'engine> { let tensor_name_cstr = std::ffi::CString::new(tensor_name).unwrap(); let tensor_name_ptr = tensor_name_cstr.as_ptr(); let buffer_ptr = buffer_ptr.as_ptr(); - eprintln!("buffer: {buffer_ptr:?}"); let success = cpp!(unsafe [ internal as "void*", tensor_name_ptr as "const char*", buffer_ptr as "void*" ] -> bool as "bool" { - fprintf(stderr, "Engine: %p\n", internal); - fprintf(stderr, "Getting tensor name: %s\n", tensor_name_ptr); const void* prevTensorAddr = ((IExecutionContext*) internal)->getTensorAddress( tensor_name_ptr); - fprintf(stderr, "Prev addr: %p", prevTensorAddr); - - fprintf(stderr, "Setting tensor name: %s, buffer ptr: %p, execution: %p\n", tensor_name_ptr, buffer_ptr, internal); return ((IExecutionContext*) internal)->setTensorAddress( tensor_name_ptr, 0 From a6b344669da5a617dc1ccee209e754615de53d23 Mon Sep 17 00:00:00 2001 From: "Micah Chambers (eos)" Date: Sat, 22 Nov 2025 16:09:36 -0800 Subject: [PATCH 9/9] undo --- src/ffi/sync/engine.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/ffi/sync/engine.rs b/src/ffi/sync/engine.rs index 3be4bcf..a4382df 100644 --- a/src/ffi/sync/engine.rs +++ b/src/ffi/sync/engine.rs @@ -406,12 +406,9 @@ impl<'engine> ExecutionContext<'engine> { tensor_name_ptr as "const char*", buffer_ptr as "void*" ] -> bool as "bool" { - const void* prevTensorAddr = ((IExecutionContext*) internal)->getTensorAddress( - tensor_name_ptr); return ((IExecutionContext*) internal)->setTensorAddress( tensor_name_ptr, - 0 - //buffer_ptr + buffer_ptr ); }); if success {