@@ -434,8 +434,8 @@ def test_context_size_exceeded_stream():
434434@pytest .mark .parametrize (
435435 "n_batch,batch_count,reuse_cache" ,
436436 [
437- (64 , 3 , False ),
438- (64 , 1 , True ),
437+ (64 , 4 , False ),
438+ (64 , 2 , True ),
439439 ]
440440)
441441def test_return_progress (n_batch , batch_count , reuse_cache ):
@@ -462,17 +462,26 @@ def make_cmpl_request():
462462 res = make_cmpl_request ()
463463 last_progress = None
464464 total_batch_count = 0
465+
465466 for data in res :
466467 cur_progress = data .get ("prompt_progress" , None )
467468 if cur_progress is None :
468469 continue
470+ if total_batch_count == 0 :
471+ # first progress report must have n_cache == n_processed
472+ assert cur_progress ["total" ] > 0
473+ assert cur_progress ["cache" ] == cur_progress ["processed" ]
474+ if reuse_cache :
475+ # when reusing cache, we expect some cached tokens
476+ assert cur_progress ["cache" ] > 0
469477 if last_progress is not None :
470478 assert cur_progress ["total" ] == last_progress ["total" ]
471479 assert cur_progress ["cache" ] == last_progress ["cache" ]
472480 assert cur_progress ["processed" ] > last_progress ["processed" ]
473481 total_batch_count += 1
474482 last_progress = cur_progress
475483
484+ # last progress should indicate completion (all tokens processed)
476485 assert last_progress is not None
477486 assert last_progress ["total" ] > 0
478487 assert last_progress ["processed" ] == last_progress ["total" ]
0 commit comments