@@ -12,14 +12,19 @@

 import asyncio
 import time
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Callable
 from typing import Any

 import httpx

 from guidellm.backends.backend import Backend
 from guidellm.backends.response_handlers import GenerationResponseHandlerFactory
-from guidellm.schemas import GenerationRequest, GenerationResponse, RequestInfo
+from guidellm.schemas import (
+    GenerationRequest,
+    GenerationRequestArguments,
+    GenerationResponse,
+    RequestInfo,
+)

 __all__ = ["OpenAIHTTPBackend"]

@@ -59,6 +64,10 @@ def __init__(
         follow_redirects: bool = True,
         verify: bool = False,
         validate_backend: bool | str | dict[str, Any] = True,
+        stream: bool = True,
+        extras: dict[str, Any] | GenerationRequestArguments | None = None,
+        max_tokens: int | None = None,
+        max_completion_tokens: int | None = None,
     ):
         """
         Initialize OpenAI HTTP backend with server configuration.
@@ -96,11 +105,28 @@ def __init__(
         self.validate_backend: dict[str, Any] | None = self._resolve_validate_kwargs(
             validate_backend
         )
+        self.stream: bool = stream
+        self.extras = (
+            GenerationRequestArguments(**extras)
+            if extras and isinstance(extras, dict)
+            else extras
+        )
+        self.max_tokens: int | None = max_tokens or max_completion_tokens

         # Runtime state
         self._in_process = False
         self._async_client: httpx.AsyncClient | None = None

+        # TODO: Find a better way to register formatters
+        self.request_formatters: dict[
+            str, Callable[[GenerationRequest], GenerationRequestArguments]
+        ] = {
+            "text_completions": self.formatter_text_completions,
+            "chat_completions": self.formatter_chat_completions,
+            "audio_transcriptions": self.formatter_audio_transcriptions,
+            "audio_translations": self.formatter_audio_transcriptions,
+        }
+
     @property
     def info(self) -> dict[str, Any]:
         """
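
A minimal usage sketch of the new constructor knobs (target, model, and the extras payload here are hypothetical, not part of this diff; `extras` is coerced to `GenerationRequestArguments` when given as a dict, and `max_tokens` only applies when the dataset row does not pin an output length):

```python
# Hypothetical illustration of the new stream/extras/max_tokens parameters.
backend = OpenAIHTTPBackend(
    target="http://localhost:8000",  # assumed server URL
    model="example-model",           # assumed model name
    stream=False,                    # send non-streaming requests
    extras={"headers": {"Authorization": "Bearer <token>"}},  # merged into every request
    max_tokens=256,                  # fallback output-token cap
)
```
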
@@ -233,31 +259,35 @@ async def resolve(  # type: ignore[override]
         if history is not None:
             raise NotImplementedError("Multi-turn requests not yet supported")

+        arguments: GenerationRequestArguments = self.request_formatters[
+            request.request_type
+        ](request)
+
         if (request_path := self.api_routes.get(request.request_type)) is None:
             raise ValueError(f"Unsupported request type '{request.request_type}'")

         request_url = f"{self.target}/{request_path}"
         request_files = (
             {
                 key: tuple(value) if isinstance(value, list) else value
-                for key, value in request.arguments.files.items()
+                for key, value in arguments.files.items()
             }
-            if request.arguments.files
+            if arguments.files
             else None
         )
-        request_json = request.arguments.body if not request_files else None
-        request_data = request.arguments.body if request_files else None
+        request_json = arguments.body if not request_files else None
+        request_data = arguments.body if request_files else None
         response_handler = GenerationResponseHandlerFactory.create(
             request.request_type, handler_overrides=self.response_handlers
         )

-        if not request.arguments.stream:
+        if not arguments.stream:
             request_info.timings.request_start = time.time()
             response = await self._async_client.request(
-                request.arguments.method or "POST",
+                arguments.method or "POST",
                 request_url,
-                params=request.arguments.params,
-                headers=request.arguments.headers,
+                params=arguments.params,
+                headers=arguments.headers,
                 json=request_json,
                 data=request_data,
                 files=request_files,
@@ -272,10 +302,10 @@ async def resolve(  # type: ignore[override]
             request_info.timings.request_start = time.time()

             async with self._async_client.stream(
-                request.arguments.method or "POST",
+                arguments.method or "POST",
                 request_url,
-                params=request.arguments.params,
-                headers=request.arguments.headers,
+                params=arguments.params,
+                headers=arguments.headers,
                 json=request_json,
                 data=request_data,
                 files=request_files,
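
For context, a sketch (with made-up values, not part of the diff) of the `request_json` / `request_data` split used in `resolve` above: a body accompanying multipart uploads goes out as form data rather than JSON.

```python
# Illustrative payload split: files present means form data, not JSON.
arguments_files = {"file": ("clip.wav", b"<bytes>", "audio/wav")}  # hypothetical
body = {"model": "example-model"}
request_json = body if not arguments_files else None  # JSON-only requests
request_data = body if arguments_files else None      # form fields beside files
assert request_json is None and request_data == body
```
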
@@ -338,3 +368,177 @@ def _resolve_validate_kwargs(
             validate_kwargs["method"] = "GET"

         return validate_kwargs
+
+    def formatter_text_completions(
+        self, data: GenerationRequest
+    ) -> GenerationRequestArguments:
+        arguments: GenerationRequestArguments = GenerationRequestArguments()
+        arguments.body = {}  # The type checker works better setting this field here
+
+        # Add model
+        if self.model is not None:
+            arguments.body["model"] = self.model
+
+        # Configure streaming
+        if self.stream:
+            arguments.stream = True
+            arguments.body["stream"] = True
+            arguments.body["stream_options"] = {"include_usage": True}
+
+        # Handle output tokens
+        if data.output_metrics.text_tokens:
+            arguments.body["max_tokens"] = data.output_metrics.text_tokens
+            arguments.body["stop"] = None
+            arguments.body["ignore_eos"] = True
+        elif self.max_tokens is not None:
+            arguments.body["max_tokens"] = self.max_tokens
+
+        # Apply extra arguments
+        if self.extras:
+            arguments.model_combine(self.extras)
+
+        # Build prompt
+        prefix = "".join(pre for pre in data.columns.get("prefix_column", []) if pre)
+        text = "".join(txt for txt in data.columns.get("text_column", []) if txt)
+        if prefix or text:
+            prompt = prefix + text
+            arguments.body["prompt"] = prompt
+
+        return arguments
+
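
As a worked example (values assumed, not part of the diff), a text-completions request with a prefix column, a text column, and a pinned output length would produce roughly this body:

```python
# Expected shape of arguments.body from formatter_text_completions, assuming
# model="example-model", stream=True, and output_metrics.text_tokens=128.
expected_body = {
    "model": "example-model",
    "stream": True,
    "stream_options": {"include_usage": True},
    "max_tokens": 128,
    "stop": None,        # stop sequences disabled so output length is exact
    "ignore_eos": True,  # keep generating past end-of-sequence
    "prompt": "<prefix><text>",  # prefix and text columns concatenated
}
```
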
+    def formatter_chat_completions(  # noqa: C901, PLR0912, PLR0915
+        self, data: GenerationRequest
+    ) -> GenerationRequestArguments:
+        arguments = GenerationRequestArguments()
+        arguments.body = {}  # The type checker works best with body assigned here
+
+        # Add model
+        if self.model is not None:
+            arguments.body["model"] = self.model
+
+        # Configure streaming
+        if self.stream:
+            arguments.stream = True
+            arguments.body["stream"] = True
+            arguments.body["stream_options"] = {"include_usage": True}
+
+        # Handle output tokens
+        if data.output_metrics.text_tokens:
+            arguments.body.update(
+                {
+                    "max_completion_tokens": data.output_metrics.text_tokens,
+                    "stop": None,
+                    "ignore_eos": True,
+                }
+            )
+        elif self.max_tokens is not None:
+            arguments.body["max_completion_tokens"] = self.max_tokens
+
+        # Apply extra arguments
+        if self.extras:
+            arguments.model_combine(self.extras)
+
+        # Build messages
+        arguments.body["messages"] = []
+
+        for prefix in data.columns.get("prefix_column", []):
+            if not prefix:
+                continue
+
+            arguments.body["messages"].append({"role": "system", "content": prefix})
+
+        for text in data.columns.get("text_column", []):
+            if not text:
+                continue
+
+            arguments.body["messages"].append(
+                {"role": "user", "content": [{"type": "text", "text": text}]}
+            )
+
+        for image in data.columns.get("image_column", []):
+            if not image:
+                continue
+
+            arguments.body["messages"].append(
+                {
+                    "role": "user",
+                    "content": [{"type": "image_url", "image_url": image.get("image")}],
+                }
+            )
+
+        for video in data.columns.get("video_column", []):
+            if not video:
+                continue
+
+            arguments.body["messages"].append(
+                {
+                    "role": "user",
+                    "content": [{"type": "video_url", "video_url": video.get("video")}],
+                }
+            )
+
+        for audio in data.columns.get("audio_column", []):
+            if not audio:
+                continue
+
+            arguments.body["messages"].append(
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "input_audio",
+                            "input_audio": {
+                                "data": audio.get("audio"),
+                                "format": audio.get("format"),
+                            },
+                        }
+                    ],
+                }
+            )
+
+        return arguments
+
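
Illustrative output (values assumed): one prefix column and one text column map to a system message followed by a content-parts user message.

```python
# Hypothetical arguments.body["messages"] produced by formatter_chat_completions
# for a request with prefix_column=["Be terse."] and text_column=["Hi!"].
expected_messages = [
    {"role": "system", "content": "Be terse."},
    {"role": "user", "content": [{"type": "text", "text": "Hi!"}]},
]
```
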
+    def formatter_audio_transcriptions(  # noqa: C901
+        self, data: GenerationRequest
+    ) -> GenerationRequestArguments:
+        arguments = GenerationRequestArguments(files={})
+        arguments.body = {}
+
+        # Add model
+        if self.model is not None:
+            arguments.body["model"] = self.model
+
+        # Configure streaming
+        if self.stream:
+            arguments.stream = True
+            arguments.body["stream"] = True
+            arguments.body["stream_options"] = {"include_usage": True}
+
+        # Apply extra arguments
+        if self.extras:
+            arguments.model_combine(self.extras)
+
+        # Build audio input
+        audio_columns = data.columns.get("audio_column", [])
+        if len(audio_columns) != 1:
+            raise ValueError(
+                f"GenerativeAudioTranscriptionRequestFormatter expects exactly "
+                f"one audio column, but got {len(audio_columns)}."
+            )
+
+        arguments.files = {
+            "file": (
+                audio_columns[0].get("file_name", "audio_input"),
+                audio_columns[0].get("audio"),
+                audio_columns[0].get("mimetype"),
+            )
+        }
+
+        # Build prompt
+        prefix = "".join(pre for pre in data.columns.get("prefix_column", []) if pre)
+        text = "".join(txt for txt in data.columns.get("text_column", []) if txt)
+        if prefix or text:
+            prompt = prefix + text
+            arguments.body["prompt"] = prompt
+
+        return arguments
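
A sketch of the resulting multipart request (filename, bytes, and mimetype are placeholders): the audio tuple is sent via `files`, so the body travels as form data, matching the `request_data` branch in `resolve`.

```python
# Hypothetical payload built by formatter_audio_transcriptions for one
# audio column with file_name="clip.wav" and mimetype="audio/wav".
expected_files = {"file": ("clip.wav", b"<wav bytes>", "audio/wav")}
expected_body = {
    "model": "example-model",
    "stream": True,
    "stream_options": {"include_usage": True},
}
```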