
Commit b940a48

Split formatter across multiple stages
Signed-off-by: Samuel Monson <smonson@redhat.com>
1 parent: 1012c1e

File tree

9 files changed: +455 -445 lines


src/guidellm/backends/openai.py

Lines changed: 217 additions & 13 deletions
@@ -12,14 +12,19 @@

 import asyncio
 import time
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Callable
 from typing import Any

 import httpx

 from guidellm.backends.backend import Backend
 from guidellm.backends.response_handlers import GenerationResponseHandlerFactory
-from guidellm.schemas import GenerationRequest, GenerationResponse, RequestInfo
+from guidellm.schemas import (
+    GenerationRequest,
+    GenerationRequestArguments,
+    GenerationResponse,
+    RequestInfo,
+)

 __all__ = ["OpenAIHTTPBackend"]

@@ -59,6 +64,10 @@ def __init__(
         follow_redirects: bool = True,
         verify: bool = False,
         validate_backend: bool | str | dict[str, Any] = True,
+        stream: bool = True,
+        extras: dict[str, Any] | GenerationRequestArguments | None = None,
+        max_tokens: int | None = None,
+        max_completion_tokens: int | None = None,
     ):
         """
         Initialize OpenAI HTTP backend with server configuration.
@@ -96,11 +105,28 @@ def __init__(
         self.validate_backend: dict[str, Any] | None = self._resolve_validate_kwargs(
             validate_backend
         )
+        self.stream: bool = stream
+        self.extras = (
+            GenerationRequestArguments(**extras)
+            if extras and isinstance(extras, dict)
+            else extras
+        )
+        self.max_tokens: int | None = max_tokens or max_completion_tokens

         # Runtime state
         self._in_process = False
         self._async_client: httpx.AsyncClient | None = None

+        # TODO: Find a better way to register formatters
+        self.request_formatters: dict[
+            str, Callable[[GenerationRequest], GenerationRequestArguments]
+        ] = {
+            "text_completions": self.formatter_text_completions,
+            "chat_completions": self.formatter_chat_completions,
+            "audio_transcriptions": self.formatter_audio_transcriptions,
+            "audio_translations": self.formatter_audio_transcriptions,
+        }
+
     @property
     def info(self) -> dict[str, Any]:
         """
@@ -233,31 +259,35 @@ async def resolve( # type: ignore[override]
         if history is not None:
             raise NotImplementedError("Multi-turn requests not yet supported")

+        arguments: GenerationRequestArguments = self.request_formatters[
+            request.request_type
+        ](request)
+
         if (request_path := self.api_routes.get(request.request_type)) is None:
             raise ValueError(f"Unsupported request type '{request.request_type}'")

         request_url = f"{self.target}/{request_path}"
         request_files = (
             {
                 key: tuple(value) if isinstance(value, list) else value
-                for key, value in request.arguments.files.items()
+                for key, value in arguments.files.items()
             }
-            if request.arguments.files
+            if arguments.files
             else None
         )
-        request_json = request.arguments.body if not request_files else None
-        request_data = request.arguments.body if request_files else None
+        request_json = arguments.body if not request_files else None
+        request_data = arguments.body if request_files else None
         response_handler = GenerationResponseHandlerFactory.create(
             request.request_type, handler_overrides=self.response_handlers
         )

-        if not request.arguments.stream:
+        if not arguments.stream:
             request_info.timings.request_start = time.time()
             response = await self._async_client.request(
-                request.arguments.method or "POST",
+                arguments.method or "POST",
                 request_url,
-                params=request.arguments.params,
-                headers=request.arguments.headers,
+                params=arguments.params,
+                headers=arguments.headers,
                 json=request_json,
                 data=request_data,
                 files=request_files,
@@ -272,10 +302,10 @@ async def resolve( # type: ignore[override]
             request_info.timings.request_start = time.time()

             async with self._async_client.stream(
-                request.arguments.method or "POST",
+                arguments.method or "POST",
                 request_url,
-                params=request.arguments.params,
-                headers=request.arguments.headers,
+                params=arguments.params,
+                headers=arguments.headers,
                 json=request_json,
                 data=request_data,
                 files=request_files,
@@ -338,3 +368,177 @@ def _resolve_validate_kwargs(
             validate_kwargs["method"] = "GET"

         return validate_kwargs
+
+    def formatter_text_completions(
+        self, data: GenerationRequest
+    ) -> GenerationRequestArguments:
+        arguments: GenerationRequestArguments = GenerationRequestArguments()
+        arguments.body = {}  # The type checker works better setting this field here
+
+        # Add model
+        if self.model is not None:
+            arguments.body["model"] = self.model
+
+        # Configure streaming
+        if self.stream:
+            arguments.stream = True
+            arguments.body["stream"] = True
+            arguments.body["stream_options"] = {"include_usage": True}
+
+        # Handle output tokens
+        if data.output_metrics.text_tokens:
+            arguments.body["max_tokens"] = data.output_metrics.text_tokens
+            arguments.body["stop"] = None
+            arguments.body["ignore_eos"] = True
+        elif self.max_tokens is not None:
+            arguments.body["max_tokens"] = self.max_tokens
+
+        # Apply extra arguments
+        if self.extras:
+            arguments.model_combine(self.extras)
+
+        # Build prompt
+        prefix = "".join(pre for pre in data.columns.get("prefix_column", []) if pre)
+        text = "".join(txt for txt in data.columns.get("text_column", []) if txt)
+        if prefix or text:
+            prompt = prefix + text
+            arguments.body["prompt"] = prompt
+
+        return arguments
+
+    def formatter_chat_completions(  # noqa: C901, PLR0912, PLR0915
+        self, data: GenerationRequest
+    ) -> GenerationRequestArguments:
+        arguments = GenerationRequestArguments()
+        arguments.body = {}  # The type checker works best with body assigned here
+
+        # Add model
+        if self.model is not None:
+            arguments.body["model"] = self.model
+
+        # Configure streaming
+        if self.stream:
+            arguments.stream = True
+            arguments.body["stream"] = True
+            arguments.body["stream_options"] = {"include_usage": True}
+
+        # Handle output tokens
+        if data.output_metrics.text_tokens:
+            arguments.body.update(
+                {
+                    "max_completion_tokens": data.output_metrics.text_tokens,
+                    "stop": None,
+                    "ignore_eos": True,
+                }
+            )
+        elif self.max_tokens is not None:
+            arguments.body["max_completion_tokens"] = self.max_tokens
+
+        # Apply extra arguments
+        if self.extras:
+            arguments.model_combine(self.extras)
+
+        # Build messages
+        arguments.body["messages"] = []
+
+        for prefix in data.columns.get("prefix_column", []):
+            if not prefix:
+                continue
+
+            arguments.body["messages"].append({"role": "system", "content": prefix})
+
+        for text in data.columns.get("text_column", []):
+            if not text:
+                continue
+
+            arguments.body["messages"].append(
+                {"role": "user", "content": [{"type": "text", "text": text}]}
+            )
+
+        for image in data.columns.get("image_column", []):
+            if not image:
+                continue
+
+            arguments.body["messages"].append(
+                {
+                    "role": "user",
+                    "content": [{"type": "image_url", "image_url": image.get("image")}],
+                }
+            )
+
+        for video in data.columns.get("video_column", []):
+            if not video:
+                continue
+
+            arguments.body["messages"].append(
+                {
+                    "role": "user",
+                    "content": [{"type": "video_url", "video_url": video.get("video")}],
+                }
+            )
+
+        for audio in data.columns.get("audio_column", []):
+            if not audio:
+                continue
+
+            arguments.body["messages"].append(
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "input_audio",
+                            "input_audio": {
+                                "data": audio.get("audio"),
+                                "format": audio.get("format"),
+                            },
+                        }
+                    ],
+                }
+            )
+
+        return arguments
+
+    def formatter_audio_transcriptions(  # noqa: C901
+        self, data: GenerationRequest
+    ) -> GenerationRequestArguments:
+        arguments = GenerationRequestArguments(files={})
+        arguments.body = {}
+
+        # Add model
+        if self.model is not None:
+            arguments.body["model"] = self.model
+
+        # Configure streaming
+        if self.stream:
+            arguments.stream = True
+            arguments.body["stream"] = True
+            arguments.body["stream_options"] = {"include_usage": True}
+
+        # Apply extra arguments
+        if self.extras:
+            arguments.model_combine(self.extras)
+
+        # Build audio input
+        audio_columns = data.columns.get("audio_column", [])
+        if len(audio_columns) != 1:
+            raise ValueError(
+                f"GenerativeAudioTranscriptionRequestFormatter expects exactly "
+                f"one audio column, but got {len(audio_columns)}."
+            )
+
+        arguments.files = {
+            "file": (
+                audio_columns[0].get("file_name", "audio_input"),
+                audio_columns[0].get("audio"),
+                audio_columns[0].get("mimetype"),
+            )
+        }
+
+        # Build prompt
+        prefix = "".join(pre for pre in data.columns.get("prefix_column", []) if pre)
+        text = "".join(txt for txt in data.columns.get("text_column", []) if txt)
+        if prefix or text:
+            prompt = prefix + text
+            arguments.body["prompt"] = prompt
+
+        return arguments
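
Taken together, the new flow in resolve() is: look up a formatter by request.request_type, call it to build GenerationRequestArguments, then hand those arguments to httpx. Below is a minimal, self-contained sketch of that dispatch step. The two dataclasses are simplified stand-ins for the real guidellm schemas, and the formatter is trimmed to just the prompt-building part, so this illustrates the pattern rather than the actual classes.

    from dataclasses import dataclass, field
    from typing import Any, Callable

    @dataclass
    class GenerationRequest:  # simplified stand-in for guidellm.schemas.GenerationRequest
        request_type: str
        columns: dict[str, list] = field(default_factory=dict)

    @dataclass
    class GenerationRequestArguments:  # simplified stand-in as well
        body: dict[str, Any] = field(default_factory=dict)
        stream: bool = False

    def formatter_text_completions(req: GenerationRequest) -> GenerationRequestArguments:
        # Trimmed version of the formatter added above: build the prompt body only.
        args = GenerationRequestArguments()
        text = "".join(t for t in req.columns.get("text_column", []) if t)
        if text:
            args.body["prompt"] = text
        return args

    # Registry keyed by request type, mirroring OpenAIHTTPBackend.__init__:
    request_formatters: dict[str, Callable[[GenerationRequest], GenerationRequestArguments]] = {
        "text_completions": formatter_text_completions,
    }

    # The dispatch step from resolve():
    request = GenerationRequest("text_completions", {"text_column": ["Hello, ", "world"]})
    arguments = request_formatters[request.request_type](request)
    print(arguments.body)  # {'prompt': 'Hello, world'}

The payoff of this shape is that resolve() no longer needs to know which modality it is serving: supporting a new request type means adding one formatter method and one registry entry, which is also what the in-code TODO about formatter registration is pointing at.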

src/guidellm/data/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -6,12 +6,12 @@
     DatasetDeserializerFactory,
 )
 from .entrypoints import process_dataset
+from .finalizers import DatasetFinalizer, FinalizerRegistry
 from .loaders import DataLoader, DatasetsIterator
 from .preprocessors import (
     DataDependentPreprocessor,
     DatasetPreprocessor,
     PreprocessorRegistry,
-    RequestFormatter,
 )
 from .processor import ProcessorFactory
 from .schemas import GenerativeDatasetColumnType
@@ -22,13 +22,14 @@
     "DataNotSupportedError",
     "DatasetDeserializer",
     "DatasetDeserializerFactory",
+    "DatasetFinalizer",
     "DatasetPreprocessor",
     "DatasetsIterator",
+    "FinalizerRegistry",
     "GenerativeDatasetColumnType",
     "GenerativeRequestCollator",
     "PreprocessorRegistry",
     "ProcessorFactory",
-    "RequestFormatter",
     "ShortPromptStrategy",
     "process_dataset",
 ]
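
With RequestFormatter gone from the data package's public API, a finalizer stage (DatasetFinalizer, FinalizerRegistry) takes its place in the exports. The registry's implementation is not shown in this commit, so the sketch below is only a guess at the general shape such a name suggests: a key-to-class registry with decorator registration. Every name and method here is hypothetical, not guidellm's actual API.

    import json
    from typing import Callable

    class Registry:
        """Hypothetical key -> class registry; not the actual FinalizerRegistry."""

        def __init__(self) -> None:
            self._entries: dict[str, type] = {}

        def register(self, name: str) -> Callable[[type], type]:
            def decorator(cls: type) -> type:
                self._entries[name] = cls  # map the key to the decorated class
                return cls
            return decorator

        def get(self, name: str) -> type:
            return self._entries[name]

    FinalizerRegistry = Registry()  # illustrative stand-in, not the real object

    @FinalizerRegistry.register("jsonl")
    class JsonlFinalizer:  # hypothetical finalizer: serializes rows as JSON lines
        def finalize(self, rows: list[dict]) -> str:
            return "\n".join(json.dumps(row) for row in rows)

    finalizer_cls = FinalizerRegistry.get("jsonl")
    print(finalizer_cls().finalize([{"prompt": "hi"}]))  # {"prompt": "hi"}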
