diff --git a/src/ffi/sync/engine.rs b/src/ffi/sync/engine.rs
index dac5461..f1ee3ee 100644
--- a/src/ffi/sync/engine.rs
+++ b/src/ffi/sync/engine.rs
@@ -32,6 +32,44 @@ unsafe impl Send for Engine {}
 /// The TensorRT API is thread-safe with regards to all operations on [`Engine`].
 unsafe impl Sync for Engine {}
 
+/// Data type of a TensorRT tensor element.
+///
+/// Variants are declared in the same order as the integer codes that
+/// [`Engine::io_tensor_type`] maps them from (`Float` = 0 through `Fp4` = 10).
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum DataType {
+    /// 32-bit floating point format.
+    Float,
+    /// IEEE 16-bit floating-point format – has a 5 bit exponent and 11 bit significand.
+    Half,
+    /// Signed 8-bit integer representing a quantized floating-point value.
+    Int8,
+    /// Signed 32-bit integer format.
+    Int32,
+    /// 8-bit boolean. 0 = false, 1 = true, other values undefined.
+    Bool,
+    /// Unsigned 8-bit integer format. Cannot be used to represent quantized floating-point values.
+    /// Use the IdentityLayer to convert kUINT8 network-level inputs to {kFLOAT, kHALF} prior to
+    /// use with other TensorRT layers, or to convert intermediate output before kUINT8
+    /// network-level outputs from {kFLOAT, kHALF} to kUINT8. kUINT8 conversions are only supported
+    /// for {kFLOAT, kHALF}. kUINT8 to {kFLOAT, kHALF} conversion will convert the integer values
+    /// to equivalent floating point values. {kFLOAT, kHALF} to kUINT8 conversion will convert the
+    /// floating point values to integer values by truncating towards zero. This conversion has
+    /// undefined behavior for floating point values outside the range [0.0F, 256.0F) after
+    /// truncation. kUINT8 conversions are not supported for {kINT8, kINT32, kBOOL}.
+    Uint8,
+    /// Signed 8-bit floating point with 1 sign bit, 4 exponent bits, 3 mantissa bits, and exponent-bias 7.
+    Fp8,
+    /// Brain float – has an 8 bit exponent and 8 bit significand.
+    Bf16,
+    /// Signed 64-bit integer type.
+    Int64,
+    /// Signed 4-bit integer type.
+    Int4,
+    /// 4-bit floating point type: 1 bit sign, 2 bit exponent, 1 bit mantissa.
+    Fp4,
+}
+
 impl Engine {
     #[inline]
     pub(crate) fn wrap(internal: *mut std::ffi::c_void, runtime: Runtime) -> Self {
@@ -58,6 +96,54 @@ impl Engine {
         num_io_tensors as usize
     }
 
+    /// Get the data type of the IO tensor at the given index.
+    ///
+    /// When built against TensorRT 10 or later, the tensor name is looked up first; if no name
+    /// exists for `io_tensor_index`, [`DataType::Float`] is returned.
+    ///
+    /// # Arguments
+    ///
+    /// * `io_tensor_index` - Index of the IO tensor.
+    ///
+    /// # Panics
+    ///
+    /// Panics if TensorRT reports a data type code that is not one of the known [`DataType`]
+    /// variants.
+    pub fn io_tensor_type(&self, io_tensor_index: usize) -> DataType {
+        let internal = self.as_ptr();
+        let io_tensor_index = io_tensor_index as std::os::raw::c_int;
+        let data_type: i32 = cpp!(unsafe [
+            internal as "const void*",
+            io_tensor_index as "int"
+        ] -> i32 as "DataType" {
+        #if NV_TENSORRT_MAJOR >= 10
+            const char* name = ((const ICudaEngine*) internal)->getIOTensorName(io_tensor_index);
+            // No name for this index: report the float default instead of passing a null name on.
+            if (name == nullptr) {
+                return DataType::kFLOAT;
+            }
+            return ((const ICudaEngine*) internal)->getTensorDataType(name);
+        #else
+            return ((const ICudaEngine*) internal)->getBindingDataType(io_tensor_index);
+        #endif
+        });
+
+        match data_type {
+            0 => DataType::Float,
+            1 => DataType::Half,
+            2 => DataType::Int8,
+            3 => DataType::Int32,
+            4 => DataType::Bool,
+            5 => DataType::Uint8,
+            6 => DataType::Fp8,
+            7 => DataType::Bf16,
+            8 => DataType::Int64,
+            9 => DataType::Int4,
+            10 => DataType::Fp4,
+            _ => panic!("Unknown data type ({data_type}), you might be using an unsupported version of TensorRT")
+        }
+    }
+
     pub fn io_tensor_name(&self, io_tensor_index: usize) -> String {
         let internal = self.as_ptr();
         let io_tensor_index = io_tensor_index as std::os::raw::c_int;
@@ -226,6 +312,54 @@ impl<'engine> ExecutionContext<'engine> {
         )
     }
 
+    /// Bind a device buffer to the IO tensor with the given name.
+    ///
+    /// Forwards to `set_tensor_address`. Intended for use together with
+    /// [`ExecutionContext::enqueue_prebound`].
+    ///
+    /// NOTE(review): `set_tensor_address` is `unsafe` but this wrapper is safe; presumably the
+    /// caller has to keep `buffer` alive until the enqueued work has completed. Confirm.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor_name` - Name of the IO tensor.
+    /// * `buffer` - Device buffer to bind to the tensor.
+    pub fn bind(
+        &mut self,
+        tensor_name: &str,
+        buffer: &mut async_cuda::ffi::memory::DeviceBuffer,
+    ) -> Result<()> {
+        Ok(unsafe { self.set_tensor_address(tensor_name, buffer) }?)
+    }
+
+    /// Enqueue inference using the tensor addresses that were bound beforehand.
+    ///
+    /// No tensor map is passed in (unlike [`ExecutionContext::enqueue`]), which allows for
+    /// assorted types of inputs.
+    ///
+    /// # Arguments
+    ///
+    /// * `stream` - CUDA stream to enqueue on.
+    ///
+    /// # Errors
+    ///
+    /// Returns the error produced by `last_error()` when `enqueueV3` reports failure.
+    pub fn enqueue_prebound(&mut self, stream: &async_cuda::ffi::stream::Stream) -> Result<()> {
+        let internal = self.as_mut_ptr();
+        let stream_ptr = stream.as_internal().as_ptr();
+        let success = cpp!(unsafe [
+            internal as "void*",
+            stream_ptr as "const void*"
+        ] -> bool as "bool" {
+            return ((IExecutionContext*) internal)->enqueueV3((cudaStream_t) stream_ptr);
+        });
+        if success {
+            Ok(())
+        } else {
+            Err(last_error())
+        }
+    }
+
     pub fn enqueue(
         &mut self,
         io_tensors: &mut std::collections::HashMap<