Skip to content

Commit b0991d4

Browse files
committed
fix(openvino): define PartialShape bounds for tensors
1 parent 2dcb7f7 commit b0991d4

1 file changed

Lines changed: 37 additions & 12 deletions

File tree

ggml/src/ggml-openvino/ggml-decoder.cpp

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -332,51 +332,76 @@ void GgmlOvDecoder::validate_cgraph() const {
332332

333333
ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const {
334334
if (m_naive) {
335-
return input!= nullptr ? ov::PartialShape{get_shape(input)} : ov::PartialShape{get_shape(op)};
335+
return input != nullptr ? ov::PartialShape{get_shape(input)} : ov::PartialShape{get_shape(op)};
336336
}
337337
auto name = std::string(input->name);
338338
ov::PartialShape input_shape;
339339

340+
// ggml_tensor gives us exact measurements in all cases, so none of those should be -1 (as that's an
341+
// OpenVINO native convention). All of those are passed through directly.
342+
//
343+
// Cases where a tensor dimension size varies are handled case-by-case below. We provide a PartialShape to
344+
// communicate the worst-case scenario: a PartialShape has a lower and upper bound on the dimension,
345+
// used to inform OpenVINO optimizations. An issue was observed with OpenCL remote buffers not allocating
346+
// unless such a range was provided (considerations with remote memory). Although that's not the responsibility
347+
// of llama.cpp to solve, providing dimension bounds is useful nonetheless.
348+
349+
const auto prefill_upper = m_is_prefill ? m_prefill_chunk_size : 1;
350+
const auto dim_span_ctx = ov::Dimension(1, m_model_params.ctx);
351+
340352
if (is_inp_tok(input, op) || is_inp_pos(input, op)) {
341353
// tokens or positions
342-
int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1) : -1;
343-
input_shape = ov::PartialShape{1, 1, 1, len};
354+
if (m_is_static) {
355+
input_shape = ov::PartialShape{1, 1, 1, prefill_upper};
356+
} else {
357+
// NOTE: AFAICT PartialShape with min_dim == max_dim is not valid, Shape must be used.
358+
// Updating callsites to return some Shape optional to allow 1,1,1,1
359+
// might materially improve things
360+
input_shape = ov::PartialShape{1, 1, 1, ov::Dimension(1, m_prefill_chunk_size)};
361+
}
344362

345363
} else if (is_output_idx(input, op)) {
346364
// output index
347-
input_shape = ov::PartialShape{1, 1, 1, m_is_static ? m_compute_params.output_len : -1};
348-
365+
if (m_is_static) {
366+
input_shape = ov::PartialShape{1, 1, 1, m_compute_params.output_len};
367+
} else {
368+
input_shape = ov::PartialShape{1, 1, 1, ov::Dimension(1, m_compute_params.output_len)};
369+
}
349370
} else if (is_inp_mask(input, op)) {
350371
// mask
351372
if (m_is_static) {
352-
input_shape = ov::PartialShape{1, 1, m_is_prefill ? m_prefill_chunk_size : 1, m_model_params.ctx};
373+
input_shape = ov::PartialShape{1, 1, prefill_upper, m_model_params.ctx};
353374
} else if (m_is_stateful) {
354-
input_shape = ov::PartialShape{1, 1, -1, -1};
375+
input_shape = ov::PartialShape{1, 1, dim_span_ctx, dim_span_ctx};
355376
} else {
356-
input_shape = ov::PartialShape{-1, 1, -1, -1};
377+
input_shape = ov::PartialShape{dim_span_ctx, 1, dim_span_ctx, dim_span_ctx};
357378
}
358379

359380
} else if (is_kvcache(input, op)) {
360381
// kvcache
361382
input_shape = ov::PartialShape{get_shape(input)};
362383
if (!m_is_static) {
363384
// do not fix ctx size to make llama-bench work across test params
364-
input_shape[2] = -1;
385+
input_shape[2] = dim_span_ctx;
365386
}
366387
if (is_stateful()) {
367388
// Convert stateless KV cache layout [1, 1, seq, n_heads_kv * head_size]
368389
// to stateful layout [1, seq, n_heads_kv, head_size].
369390
assert(input_shape.size() == 4 && input_shape[0] == 1 && input_shape[1] == 1 &&
370391
input_shape[2].is_dynamic() &&
371392
input_shape[3] == (m_model_params.n_heads_kv * m_model_params.head_size));
372-
input_shape = {input_shape[0], ov::Dimension::dynamic(), m_model_params.n_heads_kv,
393+
input_shape = {input_shape[0], dim_span_ctx, m_model_params.n_heads_kv,
373394
m_model_params.head_size};
374395
}
375396

376397
} else if (is_kv_idx(input, op)) {
377398
// kv update index
378-
int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1) : -1;
379-
input_shape = ov::PartialShape{1, 1, 1, len};
399+
if (m_is_static) {
400+
int len = m_is_prefill ? m_prefill_chunk_size : 1;
401+
input_shape = ov::PartialShape{1, 1, 1, len};
402+
} else {
403+
input_shape = ov::PartialShape{1, 1, 1, dim_span_ctx};
404+
}
380405

381406
} else {
382407
input_shape = ov::PartialShape{get_shape(input)};

0 commit comments

Comments
 (0)