From 0c2bce4f7feed1a6c76db2a1d26b939222a1e4c6 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 23 Mar 2026 16:57:57 +0000 Subject: [PATCH] initial prototype of dynamic dim selection for models in OnnxRT --- .../migraphx/migraphx_execution_provider.cc | 507 +++++++++++++++--- .../migraphx/migraphx_execution_provider.h | 18 +- .../migraphx_execution_provider_info.cc | 4 + .../migraphx_execution_provider_info.h | 6 + 4 files changed, 467 insertions(+), 68 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 9680e1c7d151e..d368482a7c59c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -161,7 +161,9 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv external_free_{info.external_free}, external_empty_cache_{info.external_empty_cache}, max_dynamic_batch_{info.max_dynamic_batch}, - compile_batches_{info.compile_batches} { + compile_batches_{info.compile_batches}, + dynamic_dimension_index_{info.dynamic_dimension_index}, + max_dynamic_dim_size_{info.max_dynamic_dim_size} { InitProviderOrtApi(); // Set GPU device to be used and read device properties for feature usage. 
@@ -197,6 +199,8 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv GET_ENV_BOOL(migraphx_env_vars::kDumpModelOps, dump_model_ops_); GET_ENV_BOOL(migraphx_env_vars::kExhaustiveTune, exhaustive_tune_); GET_ENV_STRING(migraphx_env_vars::kCompileBatches, compile_batches_); + GET_ENV(migraphx_env_vars::kDynamicDimensionIndex, dynamic_dimension_index_, dynamic_dimension_index_ = std::stoi(dynamic_dimension_index_env)); + GET_ENV(migraphx_env_vars::kMaxDynamicDimSize, max_dynamic_dim_size_, max_dynamic_dim_size_ = std::stoull(max_dynamic_dim_size_env)); // If compile_batches is set, auto-derive max_dynamic_batch from the spec's max value if (!compile_batches_.empty()) { @@ -218,6 +222,28 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv } } + // Validate dynamic dimension configuration + if (dynamic_dimension_index_ >= 0 && max_dynamic_dim_size_ == 0) { + LOGS_DEFAULT(WARNING) << "[MIGraphX] dynamic_dimension_index=" << dynamic_dimension_index_ + << " specified but max_dynamic_dim_size=0. Disabling dynamic dimension."; + dynamic_dimension_index_ = -1; + } + if (dynamic_dimension_index_ < 0 && max_dynamic_dim_size_ > 0) { + LOGS_DEFAULT(WARNING) << "[MIGraphX] max_dynamic_dim_size=" << max_dynamic_dim_size_ + << " specified but dynamic_dimension_index<0. Disabling dynamic dimension."; + max_dynamic_dim_size_ = 0; + } + if (dynamic_dimension_index_ == 0) { + LOGS_DEFAULT(WARNING) << "[MIGraphX] dynamic_dimension_index=0 is the batch dimension. " + << "Use migraphx_max_dynamic_batch for batch dimension. Disabling dynamic dimension."; + dynamic_dimension_index_ = -1; + max_dynamic_dim_size_ = 0; + } + if (dynamic_dimension_index_ > 0 && max_dynamic_dim_size_ > 0) { + LOGS_DEFAULT(INFO) << "[MIGraphX] Dynamic dimension enabled: index=" << dynamic_dimension_index_ + << ", max_size=" << max_dynamic_dim_size_; + } + // Verify configuration correctness and adjust accordingly. 
#if HIP_VERSION_MAJOR < 6 || (HIP_VERSION_MAJOR == 6 && (HIP_VERSION_MINOR < 4 || (HIP_VERSION_MINOR == 4 && HIP_VERSION_PATCH < 2))) @@ -276,7 +302,9 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv << "\n " << migraphx_provider_option::kInt8UseNativeCalibTable << ": " << int8_use_native_calibration_table_ << "\n " << migraphx_provider_option::kModelCacheDir << ": " << model_cache_path_ << "\n " << migraphx_provider_option::kModelMaxDynamicBatch << ": " << max_dynamic_batch_ - << "\n " << migraphx_provider_option::kCompileBatches << ": " << (compile_batches_.empty() ? "(not set)" : compile_batches_); + << "\n " << migraphx_provider_option::kCompileBatches << ": " << (compile_batches_.empty() ? "(not set)" : compile_batches_) + << "\n " << migraphx_provider_option::kDynamicDimensionIndex << ": " << dynamic_dimension_index_ + << "\n " << migraphx_provider_option::kMaxDynamicDimSize << ": " << max_dynamic_dim_size_; } std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() { @@ -1400,6 +1428,41 @@ static void pad_input_tensor(const void* src_data, void* dst_data, } } +// Pad input tensor data along an arbitrary dimension (not just dim 0) +// For a tensor with shape [d0, d1, ..., d_dim, ..., dN], this pads d_dim from original_size to padded_size +// by replicating the last slice along that dimension. 
+// outer_elements = product of dims before the target dim +// inner_elements = product of dims after the target dim +static void pad_input_tensor_dim(const void* src_data, void* dst_data, + std::size_t original_size, std::size_t padded_size, + std::size_t element_size_bytes, + std::size_t outer_elements, std::size_t inner_elements, + hipStream_t stream) { + std::size_t inner_bytes = element_size_bytes * inner_elements; + std::size_t src_stride = original_size * inner_bytes; + std::size_t dst_stride = padded_size * inner_bytes; + + for (std::size_t o = 0; o < outer_elements; ++o) { + const char* src_outer = static_cast(src_data) + o * src_stride; + char* dst_outer = static_cast(dst_data) + o * dst_stride; + + // Copy original data for this outer slice + HIP_CALL_THROW(hipMemcpyAsync(dst_outer, src_outer, src_stride, + hipMemcpyDeviceToDevice, stream)); + + // Replicate the last element along the target dim + if (original_size > 0 && padded_size > original_size) { + const char* last_slice = src_outer + (original_size - 1) * inner_bytes; + char* pad_start = dst_outer + original_size * inner_bytes; + for (std::size_t i = original_size; i < padded_size; ++i) { + HIP_CALL_THROW(hipMemcpyAsync(pad_start, last_slice, inner_bytes, + hipMemcpyDeviceToDevice, stream)); + pad_start += inner_bytes; + } + } + } +} + // Allocate padded input buffers and pad the data for dynamic batching // Returns true if padding was applied, false otherwise // OPTIMIZATION: Reuses existing buffers if padded batch size matches @@ -1584,6 +1647,129 @@ static bool allocate_and_pad_inputs( return true; } +// Helper: Get element size in bytes from ONNXTensorElementDataType +static std::size_t get_element_size_bytes(ONNXTensorElementDataType elem_type) { + switch (elem_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return sizeof(float); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + return sizeof(uint16_t); + case 
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + return sizeof(int64_t); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return sizeof(int32_t); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + return sizeof(int16_t); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + return sizeof(int8_t); + default: + return sizeof(float); + } +} + +// Allocate padded input buffers and pad data along a non-batch dimension for dynamic dim support. +// Works analogously to allocate_and_pad_inputs but operates on the dimension specified by dim_index. +static bool allocate_and_pad_inputs_dim( + MIGraphXFuncState* mgx_state, + Ort::KernelContext& ctx, + int dim_index, + std::size_t original_dim_size, + std::size_t padded_dim_size, + hipStream_t stream) { + + if (padded_dim_size <= original_dim_size || mgx_state->cached_inputs.empty() || dim_index < 0) { + return false; + } + + bool can_reuse_buffers = ( + mgx_state->last_padded_dim_size == padded_dim_size && + !mgx_state->padded_input_buffers.empty() && + mgx_state->padded_input_buffers.size() == mgx_state->cached_inputs.size() + ); + + if (can_reuse_buffers) { + for (size_t i = 0; i < mgx_state->cached_inputs.size(); ++i) { + const auto& cached_inp = mgx_state->cached_inputs[i]; + auto input_tensor = ctx.GetInput(cached_inp.ort_index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shape = tensor_info.GetShape(); + + if (tensor_shape.empty() || static_cast(tensor_shape.size()) <= dim_index) continue; + + auto& padded_buf = mgx_state->padded_input_buffers[i]; + std::size_t element_size_bytes = get_element_size_bytes(tensor_info.GetElementType()); + + std::size_t outer_elements = 1; + for (int d = 0; d < dim_index; ++d) outer_elements *= tensor_shape[d]; + 
std::size_t inner_elements = 1; + for (int d = dim_index + 1; d < static_cast(tensor_shape.size()); ++d) inner_elements *= tensor_shape[d]; + + const void* original_data = input_tensor.GetTensorRawData(); + pad_input_tensor_dim(original_data, padded_buf.data, original_dim_size, padded_dim_size, + element_size_bytes, outer_elements, inner_elements, stream); + } + mgx_state->last_original_dim_size = original_dim_size; + return true; + } + + // Free old buffers + for (auto& buf : mgx_state->padded_input_buffers) { + if (buf.data != nullptr) { + HIP_CALL_THROW(hipFree(buf.data)); + buf.data = nullptr; + } + } + mgx_state->padded_input_buffers.clear(); + mgx_state->padded_input_buffers.reserve(mgx_state->cached_inputs.size()); + + for (const auto& cached_inp : mgx_state->cached_inputs) { + auto input_tensor = ctx.GetInput(cached_inp.ort_index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shape = tensor_info.GetShape(); + + if (tensor_shape.empty() || static_cast(tensor_shape.size()) <= dim_index) { + continue; + } + + std::vector padded_lens(tensor_shape.begin(), tensor_shape.end()); + padded_lens[dim_index] = padded_dim_size; + + migraphx::shape padded_mgx_shape{cached_inp.mgx_shape.type(), padded_lens}; + std::size_t padded_bytes = padded_mgx_shape.bytes(); + + void* padded_data = nullptr; + HIP_CALL_THROW(hipMalloc(&padded_data, padded_bytes)); + + std::size_t element_size_bytes = get_element_size_bytes(tensor_info.GetElementType()); + + std::size_t outer_elements = 1; + for (int d = 0; d < dim_index; ++d) outer_elements *= tensor_shape[d]; + std::size_t inner_elements = 1; + for (int d = dim_index + 1; d < static_cast(tensor_shape.size()); ++d) inner_elements *= tensor_shape[d]; + + const void* original_data = input_tensor.GetTensorRawData(); + pad_input_tensor_dim(original_data, padded_data, original_dim_size, padded_dim_size, + element_size_bytes, outer_elements, inner_elements, stream); + + MIGraphXFuncState::PaddedBuffer 
buf; + buf.data = padded_data; + buf.size_bytes = padded_bytes; + buf.mgx_shape = padded_mgx_shape; + mgx_state->padded_input_buffers.push_back(buf); + } + + mgx_state->last_original_dim_size = original_dim_size; + mgx_state->last_padded_dim_size = padded_dim_size; + return true; +} + // Helper: Extract output index from MIGraphX output parameter name // MIGraphX names outputs as "#output_0", "#output_1", etc. static int compute_output_index(const std::string_view sv) { @@ -1937,6 +2123,7 @@ static migraphx::program load_or_compile_model( // This function executes the compiled MIGraphX program and copies outputs that // were not pre-allocated (input parameters reused as outputs) to the ORT output tensors // If original_batch_size is provided and < padded batch size, slices the output to remove padding +// If dynamic_dim_index >= 0 and original_dim_size < padded_dim_size, also slices the dynamic dimension static void run_migraphx_program( std::mutex* mgx_mu_ptr, const OrtApi* api, @@ -1946,7 +2133,10 @@ static void run_migraphx_program( migraphx::program_parameters& m, const std::vector& prog_output_indices, std::size_t original_batch_size = 0, - std::size_t padded_batch_size = 0) + std::size_t padded_batch_size = 0, + int dynamic_dim_index = -1, + std::size_t original_dim_size = 0, + std::size_t padded_dim_size = 0) { void* rocm_stream; Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &rocm_stream)); @@ -1957,8 +2147,11 @@ static void run_migraphx_program( prog_outputs = prog.run_async(m, static_cast(rocm_stream)); } - bool needs_slicing = (original_batch_size > 0 && padded_batch_size > 0 && - original_batch_size < padded_batch_size); + bool needs_batch_slicing = (original_batch_size > 0 && padded_batch_size > 0 && + original_batch_size < padded_batch_size); + bool needs_dim_slicing = (dynamic_dim_index >= 0 && original_dim_size > 0 && + padded_dim_size > 0 && original_dim_size < padded_dim_size); + bool needs_slicing = needs_batch_slicing || 
needs_dim_slicing; // Process ALL outputs for proper slicing when needed auto output_num = prog_outputs->size(); @@ -1977,49 +2170,57 @@ static void run_migraphx_program( if (needs_slicing && !prog_output_indices_set.empty()) { for (std::size_t i = 0; i < output_num; ++i) { if (prog_output_indices_set.count(i) > 0) { - // This output was pre-allocated with padded shape - need to copy sliced data auto gpu_res = (*prog_outputs)[i]; migraphx::shape res_shape = gpu_res.get_shape(); auto res_lens = res_shape.lengths(); - // Create sliced shape for ORT output std::vector ort_shape{res_lens.begin(), res_lens.end()}; - if (!ort_shape.empty() && static_cast(ort_shape[0]) != original_batch_size) { + bool shape_changed = false; + + if (needs_batch_slicing && !ort_shape.empty() && + static_cast(ort_shape[0]) != original_batch_size) { ort_shape[0] = static_cast(original_batch_size); - - // Calculate bytes to copy (sliced portion only) - std::size_t bytes_per_batch = res_shape.bytes() / padded_batch_size; - std::size_t bytes_to_copy = bytes_per_batch * original_batch_size; - - // Allocate temp buffer for sliced data on GPU + shape_changed = true; + } + if (needs_dim_slicing && dynamic_dim_index >= 0 && + dynamic_dim_index < static_cast(ort_shape.size())) { + ort_shape[dynamic_dim_index] = static_cast(original_dim_size); + shape_changed = true; + } + + if (shape_changed) { + // Compute sliced total bytes + std::size_t sliced_total_elements = 1; + for (auto d : ort_shape) sliced_total_elements *= static_cast(d); + std::size_t elem_bytes = res_shape.bytes(); + std::size_t padded_total_elements = 1; + for (auto d : res_lens) padded_total_elements *= d; + std::size_t element_byte_size = (padded_total_elements > 0) ? 
elem_bytes / padded_total_elements : 0; + std::size_t bytes_to_copy = sliced_total_elements * element_byte_size; + void* temp_sliced_buffer = nullptr; auto hip_status = hipMalloc(&temp_sliced_buffer, bytes_to_copy); if (hip_status != hipSuccess) { ORT_THROW("hipMalloc failed for sliced output buffer"); } - - // Copy sliced data from MIGraphX output to temp buffer + + // For batch-only slicing the simple contiguous copy works + // For dim slicing we'd need a strided copy, but this defensive path is rarely hit HIP_CALL_THROW(hipMemcpyWithStream(temp_sliced_buffer, gpu_res.data(), bytes_to_copy, hipMemcpyDeviceToDevice, static_cast(rocm_stream))); - - // Synchronize to ensure copy is complete before allocating ORT output HIP_CALL_THROW(hipStreamSynchronize(static_cast(rocm_stream))); - // Now allocate the ORT output tensor with the SLICED shape auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size()); void* output_data = output_tensor.GetTensorMutableRawData(); - // Copy from temp buffer to ORT output HIP_CALL_THROW(hipMemcpyWithStream(output_data, temp_sliced_buffer, bytes_to_copy, hipMemcpyDeviceToDevice, static_cast(rocm_stream))); - - // Free temporary buffer (void)hipFree(temp_sliced_buffer); } } @@ -2033,19 +2234,35 @@ static void run_migraphx_program( migraphx::shape res_shape = gpu_res.get_shape(); auto res_lens = res_shape.lengths(); - // Adjust output shape if slicing is needed std::vector ort_shape{res_lens.begin(), res_lens.end()}; - if (needs_slicing && !ort_shape.empty()) { - ort_shape[0] = original_batch_size; // Slice batch dimension + if (needs_batch_slicing && !ort_shape.empty()) { + ort_shape[0] = original_batch_size; + } + if (needs_dim_slicing && dynamic_dim_index >= 0 && + dynamic_dim_index < static_cast(ort_shape.size())) { + ort_shape[dynamic_dim_index] = static_cast(original_dim_size); } auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size()); void* output_data = output_tensor.GetTensorMutableRawData(); - // 
Calculate bytes to copy (slice if needed) std::size_t bytes_to_copy = res_shape.bytes(); - if (needs_slicing && res_lens.size() > 0) { + if (needs_batch_slicing && !needs_dim_slicing && res_lens.size() > 0) { bytes_to_copy = (res_shape.bytes() / padded_batch_size) * original_batch_size; + } else if (needs_dim_slicing && !needs_batch_slicing && res_lens.size() > 0) { + std::size_t padded_dim_val = (dynamic_dim_index < static_cast(res_lens.size())) + ? res_lens[dynamic_dim_index] : 1; + if (padded_dim_val > 0) { + bytes_to_copy = (res_shape.bytes() / padded_dim_val) * original_dim_size; + } + } else if (needs_batch_slicing && needs_dim_slicing && res_lens.size() > 0) { + // Both batch and dim slicing: compute based on sliced shape + std::size_t sliced_elements = 1; + for (auto d : ort_shape) sliced_elements *= static_cast(d); + std::size_t padded_elements = 1; + for (auto d : res_lens) padded_elements *= d; + std::size_t element_bytes = (padded_elements > 0) ? res_shape.bytes() / padded_elements : 0; + bytes_to_copy = sliced_elements * element_bytes; } HIP_CALL_THROW(hipMemcpyWithStream(output_data, @@ -2075,12 +2292,28 @@ static void handle_input_shape_mismatch( const auto& map_input_name_index = mgx_state->input_name_indexes; // Build cache key from all inputs in map_input_name_index (already filtered to model inputs only) + // When dynamic dim is enabled, replace that dimension with the padded size for hash consistency + const int dyn_dim_idx = mgx_state->dynamic_dimension_index; + const bool has_dyn_dim = mgx_state->has_dynamic_dim; + std::size_t padded_dim_for_compile = 0; + std::vector all_input_shapes; for (const auto& it : map_input_name_index) { auto input_tensor = ctx.GetInput(it.second); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shape = tensor_info.GetShape(); - all_input_shapes.insert(all_input_shapes.end(), tensor_shape.begin(), tensor_shape.end()); + for (int d = 0; d < static_cast(tensor_shape.size()); ++d) { + if 
(has_dyn_dim && d == dyn_dim_idx) { + std::size_t runtime_val = static_cast(tensor_shape[d]); + std::size_t padded_val = find_nearest_compiled_batch_size( + runtime_val, mgx_state->compiled_dim_sizes); + if (padded_val == 0) padded_val = mgx_state->max_dynamic_dim_size; + padded_dim_for_compile = padded_val; + all_input_shapes.push_back(static_cast(padded_val)); + } else { + all_input_shapes.push_back(tensor_shape[d]); + } + } } auto cache_hash = make_hash(all_input_shapes); @@ -2091,7 +2324,7 @@ static void handle_input_shape_mismatch( if (it != cached_progs.end()) { prog = it->second; param_shapes = prog.get_parameter_shapes(); - return; // Early exit - no need to load from disk or compile + return; } } @@ -2102,7 +2335,8 @@ static void handle_input_shape_mismatch( } // Set input parameter shapes from runtime tensors before compilation - + // When dynamic dim is active, use the padded dim size so the compiled model + // can accommodate all runtime sizes up to that padded value for (const auto& it : map_input_name_index) { const auto& name = it.first; const auto& index = it.second; @@ -2110,6 +2344,9 @@ static void handle_input_shape_mismatch( auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shape = tensor_info.GetShape(); std::vector ort_lens(tensor_shape.begin(), tensor_shape.end()); + if (has_dyn_dim && dyn_dim_idx >= 0 && dyn_dim_idx < static_cast(ort_lens.size()) && padded_dim_for_compile > 0) { + ort_lens[dyn_dim_idx] = padded_dim_for_compile; + } cmp_options.set_input_parameter_shape(name, ort_lens); } @@ -2281,13 +2518,17 @@ static void populate_ultra_fast_caches( // Helper: Build input shapes vector in cached_inputs order (MIGraphX parameter order) // This ensures consistency between how shapes are stored and how they're compared in ultra-fast path +// When padded_batch_size > 0, dim 0 is replaced with the padded batch size. 
+// When padded_dim_size > 0 and dynamic_dim_index > 0, that dimension is replaced with padded_dim_size. static std::vector build_input_shapes_in_cached_order( MIGraphXFuncState* mgx_state, Ort::KernelContext& ctx, - std::size_t padded_batch_size = 0) + std::size_t padded_batch_size = 0, + int dynamic_dim_index = -1, + std::size_t padded_dim_size = 0) { std::vector shapes; - shapes.reserve(mgx_state->cached_inputs.size() * 4); // Estimate average 4 dims per input + shapes.reserve(mgx_state->cached_inputs.size() * 4); for (const auto& cached_inp : mgx_state->cached_inputs) { auto input_tensor = ctx.GetInput(cached_inp.ort_index); @@ -2295,13 +2536,14 @@ static std::vector build_input_shapes_in_cached_order( const auto tensor_shape = tensor_info.GetShape(); if (!tensor_shape.empty()) { - if (padded_batch_size > 0) { - // Use padded batch size for first dimension - shapes.push_back(static_cast(padded_batch_size)); - shapes.insert(shapes.end(), tensor_shape.begin() + 1, tensor_shape.end()); - } else { - // Use original shape - shapes.insert(shapes.end(), tensor_shape.begin(), tensor_shape.end()); + for (int d = 0; d < static_cast(tensor_shape.size()); ++d) { + if (d == 0 && padded_batch_size > 0) { + shapes.push_back(static_cast(padded_batch_size)); + } else if (d == dynamic_dim_index && padded_dim_size > 0) { + shapes.push_back(static_cast(padded_dim_size)); + } else { + shapes.push_back(tensor_shape[d]); + } } } } @@ -2600,8 +2842,31 @@ static bool execute_fast_path( const auto& param_shapes = mgx_state->cached_mgx_param_shapes.value(); const auto& output_shapes = mgx_state->cached_mgx_output_shapes.value(); + // Extract dynamic dimension info + std::size_t original_dim_size_fp = 0; + std::size_t padded_dim_size_fp = 0; + bool needs_dim_padding_fp = false; + const int dyn_dim_idx_fp = mgx_state->dynamic_dimension_index; + + if (mgx_state->has_dynamic_dim && !mgx_state->compiled_dim_sizes.empty()) { + for (const auto& [name, index] : map_input_name_index) { + auto 
input_tensor = ctx.GetInput(index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shape = tensor_info.GetShape(); + if (dyn_dim_idx_fp >= 0 && dyn_dim_idx_fp < static_cast(tensor_shape.size())) { + original_dim_size_fp = static_cast(tensor_shape[dyn_dim_idx_fp]); + padded_dim_size_fp = find_nearest_compiled_batch_size( + original_dim_size_fp, mgx_state->compiled_dim_sizes); + if (padded_dim_size_fp == 0) padded_dim_size_fp = mgx_state->max_dynamic_dim_size; + needs_dim_padding_fp = (padded_dim_size_fp > original_dim_size_fp); + break; + } + } + } + bool needs_slicing = (original_batch_size > 0 && padded_batch_size > 0 && - original_batch_size < padded_batch_size); + original_batch_size < padded_batch_size) || + needs_dim_padding_fp; // ═══════════════════════════════════════════════════════════════════════════ // OPTIMIZATION 2: Skip populate_ultra_fast_caches when already populated @@ -2612,15 +2877,23 @@ static bool execute_fast_path( mgx_state->ultra_fast_caches_populated = true; } - // Allocate and pad inputs if needed for dynamic batching + // Allocate and pad inputs if needed bool using_padded_inputs = false; - if (padded_batch_size > original_batch_size) { - void* rocm_stream_ptr; + void* rocm_stream_ptr = nullptr; + hipStream_t rocm_stream = nullptr; + if (padded_batch_size > original_batch_size || needs_dim_padding_fp) { Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &rocm_stream_ptr)); - auto rocm_stream = static_cast(rocm_stream_ptr); + rocm_stream = static_cast(rocm_stream_ptr); + } + if (padded_batch_size > original_batch_size) { using_padded_inputs = allocate_and_pad_inputs(mgx_state, ctx, original_batch_size, padded_batch_size, rocm_stream); } + if (needs_dim_padding_fp) { + using_padded_inputs = allocate_and_pad_inputs_dim(mgx_state, ctx, dyn_dim_idx_fp, + original_dim_size_fp, padded_dim_size_fp, + rocm_stream) || using_padded_inputs; + } // 
═══════════════════════════════════════════════════════════════════════════ // OPTIMIZATION 3: Reuse temp output buffers when slicing @@ -2628,7 +2901,8 @@ static bool execute_fast_path( std::vector temp_output_buffer_ptrs; if (needs_slicing) { temp_output_buffer_ptrs = get_or_allocate_temp_output_buffers( - mgx_state, param_shapes, output_shapes, map_input_name_index, padded_batch_size); + mgx_state, param_shapes, output_shapes, map_input_name_index, + padded_batch_size > 0 ? padded_batch_size : 1); } // Bind inputs/outputs (use temp buffers for outputs when slicing) @@ -2639,10 +2913,11 @@ static bool execute_fast_path( mgx_state->cached_prog_params = std::move(m); mgx_state->cached_prog_output_indices = std::move(prog_output_indices); - // IMPORTANT: Build last_input_shapes_raw in cached_inputs order (MIGraphX parameter order) - // This ensures ultra-fast path shape comparison uses consistent ordering mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order( - mgx_state, ctx, using_padded_inputs ? padded_batch_size : 0); + mgx_state, ctx, + using_padded_inputs && padded_batch_size > original_batch_size ? padded_batch_size : 0, + needs_dim_padding_fp ? dyn_dim_idx_fp : -1, + needs_dim_padding_fp ? padded_dim_size_fp : 0); mgx_state->last_input_shape_hash = current_hash; mgx_state->caches_valid = true; @@ -2660,10 +2935,9 @@ static bool execute_fast_path( run_migraphx_program(mgx_state->mgx_mu_ptr, api, context, ctx, prog, mgx_state->cached_prog_params.value(), mgx_state->cached_prog_output_indices, - original_batch_size, padded_batch_size); - - // NOTE: Temp output buffers are kept for reuse - they will be freed when batch size changes - // NOTE: Padded input buffers are also kept for reuse + original_batch_size, padded_batch_size, + needs_dim_padding_fp ? 
dyn_dim_idx_fp : -1, + original_dim_size_fp, padded_dim_size_fp); return true; } @@ -2795,9 +3069,16 @@ static void compile_dynamic_batch_models( }() << "]"; // Store shape without batch dimension + // When dynamic dim is active, use the max dynamic dim size for that dimension std::vector base_shape; if (tensor_shape.size() > 1) { base_shape.assign(tensor_shape.begin() + 1, tensor_shape.end()); + if (mgx_state->has_dynamic_dim && mgx_state->dynamic_dimension_index > 0) { + int base_idx = mgx_state->dynamic_dimension_index - 1; // Offset by 1 since batch is removed + if (base_idx >= 0 && base_idx < static_cast(base_shape.size())) { + base_shape[base_idx] = static_cast(mgx_state->max_dynamic_dim_size); + } + } } all_input_base_shapes.push_back(base_shape); @@ -3035,7 +3316,8 @@ static void execute_standard_path( // Run with slicing enabled run_migraphx_program(mgx_state->mgx_mu_ptr, api, context, ctx, prog, m, - prog_output_indices, original_batch_size, padded_batch_size); + prog_output_indices, original_batch_size, padded_batch_size, + -1, 0, 0); // Free temporary output buffers for (void* buf : temp_output_buffers) { @@ -3070,7 +3352,7 @@ static void execute_standard_path( mgx_state->caches_valid = true; run_migraphx_program(mgx_state->mgx_mu_ptr, api, context, ctx, prog, m, prog_output_indices, - 0, 0); // Pass 0,0 for batch sizes to indicate no slicing needed + 0, 0, -1, 0, 0); return; } @@ -3083,7 +3365,6 @@ static void execute_standard_path( mgx_state->defer_compilation, map_input_name_index, ctx, cmp_options, prog); if (!input_shape_match) { - // Invalidate caches before recompilation mgx_state->caches_valid = false; handle_input_shape_mismatch( @@ -3095,33 +3376,83 @@ static void execute_standard_path( param_shapes, input_shapes); - // Re-fetch param_shapes after recompilation param_shapes = prog.get_parameter_shapes(); } - // Fetch output shapes once auto output_shapes = prog.get_output_shapes(); - // Populate optimized caches for ultra-fast path - 
populate_ultra_fast_caches(mgx_state, param_shapes, output_shapes, map_input_name_index); + // Extract dynamic dim info for slicing + std::size_t original_dim_size_sp = 0; + std::size_t padded_dim_size_sp = 0; + bool needs_dim_slicing_sp = false; + const int dyn_dim_sp = mgx_state->dynamic_dimension_index; + + if (mgx_state->has_dynamic_dim && !mgx_state->compiled_dim_sizes.empty()) { + for (const auto& [name, index] : map_input_name_index) { + auto input_tensor = ctx.GetInput(index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shape = tensor_info.GetShape(); + if (dyn_dim_sp >= 0 && dyn_dim_sp < static_cast(tensor_shape.size())) { + original_dim_size_sp = static_cast(tensor_shape[dyn_dim_sp]); + padded_dim_size_sp = find_nearest_compiled_batch_size( + original_dim_size_sp, mgx_state->compiled_dim_sizes); + if (padded_dim_size_sp == 0) padded_dim_size_sp = mgx_state->max_dynamic_dim_size; + needs_dim_slicing_sp = (padded_dim_size_sp > original_dim_size_sp); + break; + } + } + } + + bool needs_any_slicing = (original_batch_size > 0 && padded_batch_size > 0 && + original_batch_size < padded_batch_size) || needs_dim_slicing_sp; + + populate_ultra_fast_caches(mgx_state, param_shapes, output_shapes, map_input_name_index, + original_batch_size, padded_batch_size); + + // Allocate and pad for dynamic dim if needed + if (needs_dim_slicing_sp) { + void* rocm_stream_ptr; + Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &rocm_stream_ptr)); + auto rocm_stream = static_cast(rocm_stream_ptr); + allocate_and_pad_inputs_dim(mgx_state, ctx, dyn_dim_sp, + original_dim_size_sp, padded_dim_size_sp, rocm_stream); + } - // Bind inputs and allocate outputs + std::vector temp_output_buffers_sp; auto [m, prog_output_indices] = handle_program_input_outputs( - param_shapes, output_shapes, map_input_name_index, ctx); + param_shapes, output_shapes, map_input_name_index, ctx, + needs_any_slicing, needs_any_slicing ? 
&temp_output_buffers_sp : nullptr); - // Complete cache population mgx_state->cached_prog_params = m; mgx_state->cached_prog_output_indices = prog_output_indices; - // IMPORTANT: Build last_input_shapes_raw in cached_inputs order (MIGraphX parameter order) - // This ensures ultra-fast path shape comparison uses consistent ordering - mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); + mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order( + mgx_state, ctx, 0, + needs_dim_slicing_sp ? dyn_dim_sp : -1, + needs_dim_slicing_sp ? padded_dim_size_sp : 0); mgx_state->last_input_shape_hash = current_hash; mgx_state->caches_valid = true; + // Rebind padded inputs if dim padding was applied + if (needs_dim_slicing_sp && mgx_state->padded_input_buffers.size() == mgx_state->cached_inputs.size()) { + for (size_t i = 0; i < mgx_state->cached_inputs.size(); ++i) { + const auto& inp = mgx_state->cached_inputs[i]; + const auto& padded_buf = mgx_state->padded_input_buffers[i]; + m.add(inp.name.c_str(), migraphx::argument(padded_buf.mgx_shape, padded_buf.data)); + } + } + run_migraphx_program(mgx_state->mgx_mu_ptr, api, context, ctx, prog, m, prog_output_indices, - original_batch_size, padded_batch_size); + original_batch_size, padded_batch_size, + needs_dim_slicing_sp ? 
dyn_dim_sp : -1, + original_dim_size_sp, padded_dim_size_sp); + + if (needs_any_slicing) { + for (void* buf : temp_output_buffers_sp) { + if (buf != nullptr) (void)hipFree(buf); + } + } } // Build MIGraphX ONNX options with default shapes for symbolic dimensions @@ -4065,6 +4396,24 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(VERBOSE) << "[Compile][CREATE_STATE] Static model mode for node '" << context->node_name << "'"; LOGS_DEFAULT(VERBOSE) << "[Compile][CREATE_STATE] defer_compilation=" << p->defer_compilation; } + + // Initialize dynamic dimension support if configured + if (dynamic_dimension_index_ > 0 && max_dynamic_dim_size_ > 0) { + p->dynamic_dimension_index = dynamic_dimension_index_; + p->max_dynamic_dim_size = max_dynamic_dim_size_; + p->has_dynamic_dim = true; + p->compiled_dim_sizes = generate_power_of_two_batch_sizes(max_dynamic_dim_size_); + LOGS_DEFAULT(INFO) << "[Compile][CREATE_STATE] Dynamic dimension enabled for node '" + << context->node_name << "': dim_index=" << dynamic_dimension_index_ + << ", max_size=" << max_dynamic_dim_size_ + << ", compiled sizes count=" << p->compiled_dim_sizes.size(); + // When we have a dynamic dim, we need to defer compilation if not already deferred, + // since the model shape depends on runtime values for the dynamic dimension + if (!p->defer_compilation && !p->has_dynamic_batch) { + p->defer_compilation = true; + LOGS_DEFAULT(INFO) << "[Compile][CREATE_STATE] Setting defer_compilation=true for dynamic dimension"; + } + } *state = p.release(); return 0; @@ -4090,12 +4439,36 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // ═══════════════════════════════════════════════════════════════════════ // Build input shape hash - only computed when shapes change + // When dynamic dim is enabled, replace that dimension with max size in the + // hash so all runtime sizes map to the same compiled program. 
// ═══════════════════════════════════════════════════════════════════════ std::vector all_input_shapes; all_input_shapes.reserve(map_input_name_index.size() * 4); + + std::size_t runtime_dim_size = 0; // Actual runtime size of the dynamic dim + std::size_t padded_dim_size = 0; // Padded size for the dynamic dim + const int dyn_dim_idx = mgx_state->dynamic_dimension_index; + const bool has_dyn_dim = mgx_state->has_dynamic_dim; + for (const auto& [name, index] : map_input_name_index) { const auto& shape = ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); - all_input_shapes.insert(all_input_shapes.end(), shape.begin(), shape.end()); + for (int d = 0; d < static_cast(shape.size()); ++d) { + if (has_dyn_dim && d == dyn_dim_idx) { + // Capture the runtime size from the first input that has this dim + if (runtime_dim_size == 0) { + runtime_dim_size = static_cast(shape[d]); + padded_dim_size = find_nearest_compiled_batch_size( + runtime_dim_size, mgx_state->compiled_dim_sizes); + if (padded_dim_size == 0) { + padded_dim_size = mgx_state->max_dynamic_dim_size; + } + } + // Replace with the padded/max size for hash consistency + all_input_shapes.push_back(static_cast(padded_dim_size)); + } else { + all_input_shapes.push_back(shape[d]); + } + } } const auto current_hash = make_hash(all_input_shapes); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index e14da5c2c7639..dceb604bef99b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -37,6 +37,8 @@ constexpr auto kExhaustiveTune = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"sv; constexpr auto kModelCachePath = "ORT_MIGRAPHX_MODEL_CACHE_PATH"sv; constexpr auto kModelMaxDynamicBatch = "ORT_MIGRAPHX_MAX_DYNAMIC_BATCH"sv; constexpr auto kCompileBatches = "ORT_MIGRAPHX_COMPILE_BATCHES"sv; +constexpr auto kDynamicDimensionIndex = 
"ORT_MIGRAPHX_DYNAMIC_DIMENSION_INDEX"sv; +constexpr auto kMaxDynamicDimSize = "ORT_MIGRAPHX_MAX_DYNAMIC_DIM_SIZE"sv; } // namespace migraphx_env_vars // Tracks which dimensions are symbolic for a given input @@ -73,6 +75,12 @@ struct MIGraphXFuncState { // Dynamic batch support bool has_dynamic_batch = false; std::vector compiled_batch_sizes; + + // Dynamic dimension support (non-batch dimension) + int dynamic_dimension_index = -1; // Which dimension to treat as dynamic (-1 = disabled) + size_t max_dynamic_dim_size = 0; // Max size for that dimension + bool has_dynamic_dim = false; // True if dynamic_dimension_index > 0 and max_dynamic_dim_size > 0 (index 0 is never enabled) + std::vector compiled_dim_sizes; // Compiled sizes for the dynamic dimension // Padded input buffers for dynamic batching (allocated on GPU) struct PaddedBuffer { @@ -86,6 +94,10 @@ struct MIGraphXFuncState { std::size_t last_original_batch_size = 0; // Original batch size from last run std::size_t last_padded_batch_size = 0; // Padded batch size from last run + // Track last dynamic dim sizes for reuse + std::size_t last_original_dim_size = 0; // Original dynamic dim size from last run + std::size_t last_padded_dim_size = 0; // Padded dynamic dim size from last run + // ═══════════════════════════════════════════════════════════════════════════ // PERFORMANCE CACHES - Avoid redundant MIGraphX API calls per inference // ═══════════════════════════════════════════════════════════════════════════ @@ -208,7 +220,9 @@ class MIGraphXExecutionProvider : public IExecutionProvider { {std::string{migraphx_provider_option::kGpuExternalEmptyCache}, MakeStringWithClassicLocale(external_empty_cache_)}, {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(model_cache_path_)}, {std::string{migraphx_provider_option::kModelMaxDynamicBatch}, MakeStringWithClassicLocale(max_dynamic_batch_)}, - {std::string{migraphx_provider_option::kCompileBatches}, compile_batches_}}; + 
{std::string{migraphx_provider_option::kCompileBatches}, compile_batches_}, + {std::string{migraphx_provider_option::kDynamicDimensionIndex}, MakeStringWithClassicLocale(dynamic_dimension_index_)}, + {std::string{migraphx_provider_option::kMaxDynamicDimSize}, MakeStringWithClassicLocale(max_dynamic_dim_size_)}}; } private: @@ -250,6 +264,8 @@ class MIGraphXExecutionProvider : public IExecutionProvider { bool first_start_ = true; size_t max_dynamic_batch_{0}; std::string compile_batches_{}; // Comma-separated list of batch sizes to compile, e.g. "1,4,8,16,32" + int dynamic_dimension_index_{-1}; // Non-batch dimension to treat as dynamic (-1 = disabled) + size_t max_dynamic_dim_size_{0}; // Max size for the dynamic dimension (0 = disabled) }; }; // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index 08bdd13dfc763..e40159b1e6d28 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -74,6 +74,8 @@ MIGraphXExecutionProviderInfo::MIGraphXExecutionProviderInfo(const ProviderOptio .AddAssignmentToEnumReference(migraphx_provider_option::kArenaExtendStrategy, arena_extend_strategy_mapping, arena_extend_strategy) .AddAssignmentToReference(migraphx_provider_option::kModelMaxDynamicBatch, max_dynamic_batch) .AddAssignmentToReference(migraphx_provider_option::kCompileBatches, compile_batches) + .AddAssignmentToReference(migraphx_provider_option::kDynamicDimensionIndex, dynamic_dimension_index) + .AddAssignmentToReference(migraphx_provider_option::kMaxDynamicDimSize, max_dynamic_dim_size) .Parse(options)); } @@ -107,6 +109,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions() const { {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(model_cache_dir)}, 
{std::string{migraphx_provider_option::kModelMaxDynamicBatch}, MakeStringWithClassicLocale(max_dynamic_batch)}, {std::string{migraphx_provider_option::kCompileBatches}, compile_batches}, + {std::string{migraphx_provider_option::kDynamicDimensionIndex}, MakeStringWithClassicLocale(dynamic_dimension_index)}, + {std::string{migraphx_provider_option::kMaxDynamicDimSize}, MakeStringWithClassicLocale(max_dynamic_dim_size)}, }; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index 0c9802460c3c7..6c5dd87927ab8 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -36,6 +36,8 @@ constexpr auto kGpuExternalEmptyCache = "migraphx_external_empty_cache"sv; constexpr auto kModelCacheDir = "migraphx_model_cache_dir"sv; constexpr auto kModelMaxDynamicBatch = "migraphx_max_dynamic_batch"sv; constexpr auto kCompileBatches = "migraphx_compile_batches"sv; +constexpr auto kDynamicDimensionIndex = "migraphx_dynamic_dimension_index"sv; +constexpr auto kMaxDynamicDimSize = "migraphx_max_dynamic_dim_size"sv; } // namespace migraphx_provider_option extern const EnumNameMapping arena_extend_strategy_mapping; @@ -59,6 +61,8 @@ struct MIGraphXExecutionProviderInfo { OrtArenaCfg* default_memory_arena_cfg{nullptr}; size_t max_dynamic_batch{static_cast(0)}; std::string compile_batches{}; // Comma-separated list of batch sizes to compile, e.g. 
"1,4,8,16,32" + int dynamic_dimension_index{-1}; // Non-batch dimension to treat as dynamic (-1 = disabled) + size_t max_dynamic_dim_size{static_cast(0)}; // Max size for the dynamic dimension (0 = disabled) void* external_alloc{nullptr}; void* external_free{nullptr}; @@ -106,6 +110,8 @@ struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> { onnxruntime::HashCombine(info.max_dynamic_batch, value); onnxruntime::HashCombine(info.compile_batches, value); + onnxruntime::HashCombine(static_cast(info.dynamic_dimension_index), value); + onnxruntime::HashCombine(info.max_dynamic_dim_size, value); // The default memory arena cfg is not used in hashing right now. return value; }