@@ -332,51 +332,76 @@ void GgmlOvDecoder::validate_cgraph() const {
332332
333333ov::PartialShape GgmlOvDecoder::get_graph_input_shape (const ggml_tensor * op, const ggml_tensor * input) const {
334334 if (m_naive) {
335- return input!= nullptr ? ov::PartialShape{get_shape (input)} : ov::PartialShape{get_shape (op)};
335+ return input != nullptr ? ov::PartialShape{get_shape (input)} : ov::PartialShape{get_shape (op)};
336336 }
337337 auto name = std::string (input->name );
338338 ov::PartialShape input_shape;
339339
340+ // ggml_tensor gives us exact measurements in all cases, so none of those should be -1 (as that's an
341+ // OpenVINO native convention). All of those are passed through directly.
342+ //
343+ // Cases where a tensor dimension size varies are handled case-by-case below. We provide a PartialShape to
344+ // communicate the worst-case scenario: a PartialShape has a lower and upper bound on the dimension,
345+ // used to inform OpenVINO optimizations. An issue was observed with OpenCL remote buffers not allocating
346+ // unless such a range was provided (considerations with remote memory). Although that's not the responsibility
347+ // of llama.cpp to solve, providing dimension bounds is useful nonetheless.
348+
349+ const auto prefill_upper = m_is_prefill ? m_prefill_chunk_size : 1 ;
350+ const auto dim_span_ctx = ov::Dimension (1 , m_model_params.ctx );
351+
340352 if (is_inp_tok (input, op) || is_inp_pos (input, op)) {
341353 // tokens or positions
342- int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1 ) : -1 ;
343- input_shape = ov::PartialShape{1 , 1 , 1 , len};
354+ if (m_is_static) {
355+ input_shape = ov::PartialShape{1 , 1 , 1 , prefill_upper};
356+ } else {
357+ // NOTE: AFAICT PartialShape with min_dim == max_dim is not valid, Shape must be used.
358+ // Updating callsites to return some Shape optional to allow 1,1,1,1
359+ // might materially improve things
360+ input_shape = ov::PartialShape{1 , 1 , 1 , ov::Dimension (1 , m_prefill_chunk_size)};
361+ }
344362
345363 } else if (is_output_idx (input, op)) {
346364 // output index
347- input_shape = ov::PartialShape{1 , 1 , 1 , m_is_static ? m_compute_params.output_len : -1 };
348-
365+ if (m_is_static) {
366+ input_shape = ov::PartialShape{1 , 1 , 1 , m_compute_params.output_len };
367+ } else {
368+ input_shape = ov::PartialShape{1 , 1 , 1 , ov::Dimension (1 , m_compute_params.output_len )};
369+ }
349370 } else if (is_inp_mask (input, op)) {
350371 // mask
351372 if (m_is_static) {
352- input_shape = ov::PartialShape{1 , 1 , m_is_prefill ? m_prefill_chunk_size : 1 , m_model_params.ctx };
373+ input_shape = ov::PartialShape{1 , 1 , prefill_upper , m_model_params.ctx };
353374 } else if (m_is_stateful) {
354- input_shape = ov::PartialShape{1 , 1 , - 1 , - 1 };
375+ input_shape = ov::PartialShape{1 , 1 , dim_span_ctx, dim_span_ctx };
355376 } else {
356- input_shape = ov::PartialShape{- 1 , 1 , - 1 , - 1 };
377+ input_shape = ov::PartialShape{dim_span_ctx , 1 , dim_span_ctx, dim_span_ctx };
357378 }
358379
359380 } else if (is_kvcache (input, op)) {
360381 // kvcache
361382 input_shape = ov::PartialShape{get_shape (input)};
362383 if (!m_is_static) {
363384 // do not fix ctx size to make llama-bench work across test params
364- input_shape[2 ] = - 1 ;
385+ input_shape[2 ] = dim_span_ctx ;
365386 }
366387 if (is_stateful ()) {
367388 // Convert stateless KV cache layout [1, 1, seq, n_heads_kv * head_size]
368389 // to stateful layout [1, seq, n_heads_kv, head_size].
369390 assert (input_shape.size () == 4 && input_shape[0 ] == 1 && input_shape[1 ] == 1 &&
370391 input_shape[2 ].is_dynamic () &&
371392 input_shape[3 ] == (m_model_params.n_heads_kv * m_model_params.head_size ));
372- input_shape = {input_shape[0 ], ov::Dimension::dynamic () , m_model_params.n_heads_kv ,
393+ input_shape = {input_shape[0 ], dim_span_ctx , m_model_params.n_heads_kv ,
373394 m_model_params.head_size };
374395 }
375396
376397 } else if (is_kv_idx (input, op)) {
377398 // kv update index
378- int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1 ) : -1 ;
379- input_shape = ov::PartialShape{1 , 1 , 1 , len};
399+ if (m_is_static) {
400+ int len = m_is_prefill ? m_prefill_chunk_size : 1 ;
401+ input_shape = ov::PartialShape{1 , 1 , 1 , len};
402+ } else {
403+ input_shape = ov::PartialShape{1 , 1 , 1 , dim_span_ctx};
404+ }
380405
381406 } else {
382407 input_shape = ov::PartialShape{get_shape (input)};
0 commit comments