Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions src/ffi/sync/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,39 @@ unsafe impl Send for Engine {}
/// The TensorRT API is thread-safe with regards to all operations on [`Engine`].
unsafe impl Sync for Engine {}

/// Element type of a tensor, as reported for an engine's I/O tensors.
///
/// The variants are listed in the same order as the integer codes `0..=10` that
/// `Engine::io_tensor_type` maps onto them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// 32-bit floating point format.
    Float,
    /// IEEE 16-bit floating-point format – has a 5 bit exponent and 11 bit significand.
    Half,
    /// Signed 8-bit integer representing a quantized floating-point value.
    Int8,
    /// Signed 32-bit integer format.
    Int32,
    /// 8-bit boolean. 0 = false, 1 = true, other values undefined.
    Bool,
    /// Unsigned 8-bit integer format. Cannot be used to represent quantized floating-point values.
    /// Use the IdentityLayer to convert kUINT8 network-level inputs to {kFLOAT, kHALF} prior to
    /// use with other TensorRT layers, or to convert intermediate output before kUINT8
    /// network-level outputs from {kFLOAT, kHALF} to kUINT8. kUINT8 conversions are only supported
    /// for {kFLOAT, kHALF}. kUINT8 to {kFLOAT, kHALF} conversion will convert the integer values
    /// to equivalent floating point values. {kFLOAT, kHALF} to kUINT8 conversion will convert the
    /// floating point values to integer values by truncating towards zero. This conversion has
    /// undefined behavior for floating point values outside the range [0.0F, 256.0F) after
    /// truncation. kUINT8 conversions are not supported for {kINT8, kINT32, kBOOL}.
    Uint8,
    /// Signed 8-bit floating point with 1 sign bit, 4 exponent bits, 3 mantissa bits, and
    /// exponent-bias 7.
    Fp8,
    /// Brain float – has an 8 bit exponent and 8 bit significand.
    Bf16,
    /// Signed 64-bit integer type.
    Int64,
    /// Signed 4-bit integer type.
    Int4,
    /// 4-bit floating point type with 1 sign bit, 2 exponent bits, and 1 mantissa bit.
    Fp4,
}

impl Engine {
#[inline]
pub(crate) fn wrap(internal: *mut std::ffi::c_void, runtime: Runtime) -> Self {
Expand All @@ -58,6 +91,40 @@ impl Engine {
num_io_tensors as usize
}

/// Get the data type of the I/O tensor at the given index.
///
/// On TensorRT 10 and later the tensor name is looked up first and the type is queried by
/// name; on older versions the type is queried directly by binding index.
///
/// # Arguments
///
/// * `io_tensor_index` - Index of the I/O tensor.
///
/// # Return value
///
/// The tensor's [`DataType`]. On TensorRT 10 and later, if no name can be resolved for the
/// index (the name lookup returns a null pointer), [`DataType::Float`] is returned.
///
/// # Panics
///
/// Panics if TensorRT reports a data type code that has no [`DataType`] variant.
pub fn io_tensor_type(&self, io_tensor_index: usize) -> DataType {
    let internal = self.as_ptr();
    let io_tensor_index = io_tensor_index as std::os::raw::c_int;
    let data_type: i32 = cpp!(unsafe [
        internal as "const void*",
        io_tensor_index as "int"
    ] -> i32 as "DataType" {
        #if NV_TENSORRT_MAJOR >= 10
        const char* name = ((const ICudaEngine*) internal)->getIOTensorName(io_tensor_index);
        if(name == nullptr) {
            // The C++ enumerator is `kFLOAT`; `Float` only exists on the Rust side.
            return DataType::kFLOAT;
        }
        return ((const ICudaEngine*) internal)->getTensorDataType(name);
        #else
        return ((const ICudaEngine*) internal)->getBindingDataType(io_tensor_index);
        #endif
    });

    // Translate the raw integer code into the Rust-side enum.
    match data_type {
        0 => DataType::Float,
        1 => DataType::Half,
        2 => DataType::Int8,
        3 => DataType::Int32,
        4 => DataType::Bool,
        5 => DataType::Uint8,
        6 => DataType::Fp8,
        7 => DataType::Bf16,
        8 => DataType::Int64,
        9 => DataType::Int4,
        10 => DataType::Fp4,
        _ => panic!("Unknown data type ({data_type}), you might be using an unsupported version of TensorRT")
    }
}

pub fn io_tensor_name(&self, io_tensor_index: usize) -> String {
let internal = self.as_ptr();
let io_tensor_index = io_tensor_index as std::os::raw::c_int;
Expand Down Expand Up @@ -226,6 +293,32 @@ impl<'engine> ExecutionContext<'engine> {
)
}

/// Bind a device buffer to the tensor with the given name.
///
/// This forwards to `set_tensor_address` and propagates its error, if any.
///
/// # Arguments
///
/// * `tensor_name` - Name of the tensor to bind.
/// * `buffer` - Device buffer to use for that tensor.
pub fn bind<T: Copy>(
    &mut self,
    tensor_name: &str,
    buffer: &mut async_cuda::ffi::memory::DeviceBuffer<T>,
) -> Result<()> {
    // NOTE(review): `set_tensor_address` is `unsafe` and its contract is not visible here.
    // Presumably the buffer must stay valid until the enqueued work that uses it has
    // finished; confirm against its documentation.
    let bound = unsafe { self.set_tensor_address(tensor_name, buffer) };
    Ok(bound?)
}

/// Enqueue inference on the given stream without setting any tensor addresses.
///
/// Unlike `enqueue`, this takes no tensor map, which allows inputs and outputs of assorted
/// element types. Tensor addresses are expected to have been set beforehand (for example
/// through `bind`).
///
/// # Arguments
///
/// * `stream` - CUDA stream to enqueue the work on.
///
/// # Errors
///
/// Returns the value of `last_error()` if `enqueueV3` reports failure.
pub fn enqueue_prebound(&mut self, stream: &async_cuda::ffi::stream::Stream) -> Result<()> {
    let context_ptr = self.as_mut_ptr();
    let cuda_stream = stream.as_internal().as_ptr();
    let enqueued = cpp!(unsafe [
        context_ptr as "void*",
        cuda_stream as "const void*"
    ] -> bool as "bool" {
        return ((IExecutionContext*) context_ptr)->enqueueV3((cudaStream_t) cuda_stream);
    });
    if !enqueued {
        return Err(last_error());
    }
    Ok(())
}

pub fn enqueue<T: Copy>(
&mut self,
io_tensors: &mut std::collections::HashMap<
Expand Down