Skip to content

Commit 5ee4e43

Browse files
authored
server: return_progress to also report 0% processing state (#18305)
1 parent 5b6c9bc commit 5ee4e43

2 files changed

Lines changed: 17 additions & 2 deletions

File tree

tools/server/server-context.cpp

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2313,6 +2313,12 @@ struct server_context_impl {
23132313
slot.n_prompt_tokens_processed = 0;
23142314

23152315
slot.prompt.tokens.keep_first(n_past);
2316+
2317+
// send initial 0% progress update if needed
2318+
// this is to signal the client that the request has started processing
2319+
if (slot.task->params.stream && slot.task->params.return_progress) {
2320+
send_partial_response(slot, {}, true);
2321+
}
23162322
}
23172323

23182324
if (!slot.can_split()) {

tools/server/tests/unit/test_chat_completion.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -434,8 +434,8 @@ def test_context_size_exceeded_stream():
434434
@pytest.mark.parametrize(
435435
"n_batch,batch_count,reuse_cache",
436436
[
437-
(64, 3, False),
438-
(64, 1, True),
437+
(64, 4, False),
438+
(64, 2, True),
439439
]
440440
)
441441
def test_return_progress(n_batch, batch_count, reuse_cache):
@@ -462,17 +462,26 @@ def make_cmpl_request():
462462
res = make_cmpl_request()
463463
last_progress = None
464464
total_batch_count = 0
465+
465466
for data in res:
466467
cur_progress = data.get("prompt_progress", None)
467468
if cur_progress is None:
468469
continue
470+
if total_batch_count == 0:
471+
# first progress report must have n_cache == n_processed
472+
assert cur_progress["total"] > 0
473+
assert cur_progress["cache"] == cur_progress["processed"]
474+
if reuse_cache:
475+
# when reusing cache, we expect some cached tokens
476+
assert cur_progress["cache"] > 0
469477
if last_progress is not None:
470478
assert cur_progress["total"] == last_progress["total"]
471479
assert cur_progress["cache"] == last_progress["cache"]
472480
assert cur_progress["processed"] > last_progress["processed"]
473481
total_batch_count += 1
474482
last_progress = cur_progress
475483

484+
# last progress should indicate completion (all tokens processed)
476485
assert last_progress is not None
477486
assert last_progress["total"] > 0
478487
assert last_progress["processed"] == last_progress["total"]

0 commit comments

Comments
 (0)