@@ -16,7 +16,7 @@
 import hashlib
 import logging
 import os
-from typing import Optional, Union
+from typing import Literal, Optional, Union

 import torch
 from peft import PeftModel
@@ -37,6 +37,7 @@
 )
 from lmflow.utils.conversation_template import PRESET_TEMPLATES
 from lmflow.utils.data_utils import VLLMInferenceResultWithInput
+from lmflow.utils.deprecated import deprecated_args
 from lmflow.utils.envs import is_accelerate_env
 from lmflow.utils.versioning import is_flash_attn_available, is_ray_available, is_vllm_available
@@ -273,43 +274,85 @@ def decode(self, input, **kwargs) -> Union[str, list[str]]:
         # Can be list of ints or a Tensor
         return self.tokenizer.decode(input, **kwargs)

-    def inference(self, inputs, release_gpu: bool = False, use_vllm: bool = False, **kwargs):
+    @deprecated_args(
+        use_vllm={
+            "replacement": "inference_engine",
+            "mapper": lambda x: "vllm" if x is True else "huggingface",
+            "message": (
+                "use_vllm is deprecated and will be removed in a future version. "
+                "Please use `inference_engine='vllm'` instead."
+            ),
+        }
+    )
+    def inference(
+        self,
+        inputs: Union[str, list[str], torch.Tensor],
+        sampling_params: Optional[Union[dict, "SamplingParams"]] = None,
+        return_logprob: bool = False,
+        release_gpu: bool = False,
+        inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
+        gpu_memory_utilization: Optional[float] = None,
+        tensor_parallel_size: Optional[int] = None,
+        enable_deterministic_inference: bool = False,
+        attention_backend: Optional[str] = None,
+        **kwargs,
+    ):
277300 """
278301 Perform generation process of the model.
279302
280303 Parameters
281304 ------------
282- inputs :
305+ inputs : Union[str, list[str], torch.Tensor]
283306 The sequence used as a prompt for the generation or as model inputs to the model.
284- When using vllm inference, this should be a string or a list of strings.
285- When using normal inference, this should be a tensor.
307+ When the inference engine is "vllm" or "sglang", this should be a string or a list of strings.
308+ When the inference engine is "huggingface", this should be a tensor.
309+ sampling_params : Optional[Union[dict, "SamplingParams"]], optional
310+ The sampling parameters to use, by default None.
311+ return_logprob : bool, optional
312+ Whether to return log probability during inference, by default False.
286313 release_gpu : bool, optional
287314 Whether to release the GPU resource after inference, by default False.
288- use_vllm : bool, optional
289- Whether to use VLLM for inference, by default False.
290- kwargs : Optional.
291- Keyword arguments.
315+ inference_engine : Literal["huggingface", "vllm", "sglang"], optional
316+ The inference engine to use, by default "huggingface".
317+ gpu_memory_utilization : float, optional
318+ The GPU memory utilization to use, by default None.
319+ tensor_parallel_size : int, optional
320+ The tensor parallel size to use, by default None.
321+ enable_deterministic_inference : bool, optional
322+ Whether to enable deterministic inference, by default False.
323+ attention_backend : Optional[str], optional
324+ The attention backend to use, by default None.
292325
293326 Returns
294327 ------------
295328 outputs :
296329 The generated sequence output
297330 """
+        if isinstance(inputs, str):
+            inputs = [inputs]
+
         if not self._activated:
             self.activate_model_for_inference(
-                use_vllm=use_vllm,
-                **kwargs,
+                inference_engine=inference_engine,
+                gpu_memory_utilization=gpu_memory_utilization,
+                tensor_parallel_size=tensor_parallel_size,
+                enable_deterministic_inference=enable_deterministic_inference,
+                attention_backend=attention_backend,
             )

-        if use_vllm:
-            if not is_vllm_available():
-                raise ImportError("vllm is not installed. Please install vllm to use VLLM inference.")
-            res = self.__vllm_inference(inputs, **kwargs)
+        if inference_engine == "vllm":
+            res = self.__vllm_inference(inputs=inputs, sampling_params=sampling_params)
+        elif inference_engine == "sglang":
+            res = self.__sglang_inference(
+                inputs=inputs,
+                sampling_params=sampling_params,
+                return_logprob=return_logprob,
+            )
         else:
-            res = self.__inference(inputs, **kwargs)
+            res = self.__inference(inputs=inputs, **kwargs)

         if release_gpu:
-            self.deactivate_model_for_inference(use_vllm=use_vllm)
+            self.deactivate_model_for_inference(inference_engine=inference_engine)

         return res

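The `deprecated_args` decorator is imported from `lmflow.utils.deprecated`, whose implementation is not part of this diff. As a minimal sketch of what such a decorator could look like, inferred only from how it is invoked above (an assumption, not the actual lmflow implementation):

import functools
import warnings


def deprecated_args(**deprecations):
    # Illustrative only: redirect deprecated keyword arguments to their replacements.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for old_name, spec in deprecations.items():
                if old_name in kwargs:
                    # Point the warning at the caller of the decorated function.
                    warnings.warn(spec["message"], DeprecationWarning, stacklevel=2)
                    value = kwargs.pop(old_name)
                    # Map the old value onto the new argument unless it was passed explicitly.
                    kwargs.setdefault(spec["replacement"], spec["mapper"](value))
            return func(*args, **kwargs)

        return wrapper

    return decorator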
@@ -353,15 +396,14 @@ def __inference(self, inputs, *args, **kwargs):

     def __vllm_inference(
         self,
-        inputs: Union[str, list[str]],
+        inputs: list[str],
         sampling_params: Optional["SamplingParams"] = None,
-        **kwargs,
     ) -> list[VLLMInferenceResultWithInput]:
         """Perform VLLM inference process of the model.

         Parameters
         ----------
-        inputs : Union[str, list[str]]
-            Prompt(s), string or a list of strings.
+        inputs : list[str]
+            Prompts, a list of strings.
         sampling_params : Optional[SamplingParams], optional
             vllm SamplingParams object, by default None.
@@ -383,6 +425,7 @@ def __vllm_inference(
             sampling_params=sampling_params,
             use_tqdm=True,
         )
+        # TODO: unify outputs into the lmflow sample format
         final_output = []
         for output in vllm_outputs:
             if sampling_params.detokenize:
@@ -394,54 +437,39 @@ def __vllm_inference(

         return final_output

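As a usage sketch for the vllm path (the model instance, prompt, and sampling values below are hypothetical, not taken from this PR):

from vllm import SamplingParams

# `model` is assumed to be an instance of this model class.
outputs = model.inference(
    ["What is LMFlow?"],
    sampling_params=SamplingParams(temperature=0.7, max_tokens=128),
    inference_engine="vllm",
)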
-    def prepare_inputs_for_inference(
+    def __sglang_inference(
         self,
-        dataset: Dataset,
-        apply_chat_template: bool = True,
-        enable_distributed_inference: bool = False,
-        use_vllm: bool = False,
-        **kwargs,
-    ) -> Union[list[str], "ray.data.Dataset", dict[str, torch.Tensor]]:
-        """
-        Prepare inputs for inference.
-
-        Parameters
-        ------------
-        dataset : lmflow.datasets.Dataset.
-            The dataset used for inference.
-
-        args : Optional.
-            Positional arguments.
-
-        kwargs : Optional.
-            Keyword arguments.
-
-        Returns
-        ------------
-        outputs :
-            The prepared inputs for inference.
-        """
-        if use_vllm:
-            if not is_ray_available() and enable_distributed_inference:
-                raise ImportError("ray is not installed. Please install ray to use distributed vllm inference.")
-            inference_inputs = self.__prepare_inputs_for_vllm_inference(
-                dataset=dataset,
-                apply_chat_template=apply_chat_template,
-                enable_distributed_inference=enable_distributed_inference,
-            )
-        else:
-            inference_inputs = self.__prepare_inputs_for_inference(
-                dataset,
-                apply_chat_template=apply_chat_template,
-                enable_distributed_inference=enable_distributed_inference,
-            )
-
-        return inference_inputs
-
-    def __prepare_inputs_for_vllm_inference(
+        inputs: list[str],
+        sampling_params: Optional[dict] = None,
+        return_logprob: bool = False,
+    ):
+        """Perform SGLang inference process of the model."""
+        sglang_outputs = self.backend_model_for_inference.generate(
+            prompt=inputs,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+        )
+        # TODO: unify outputs into the lmflow sample format
+        for idx, output in enumerate(sglang_outputs):
+            output["input"] = inputs[idx]
+            output["output"] = output.pop("text")
+        return sglang_outputs
+
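A corresponding usage sketch for the sglang path; SGLang takes sampling parameters as a plain dict, and the model instance and values below are hypothetical:

# `model` is assumed to be an instance of this model class.
results = model.inference(
    ["Hello, who are you?"],
    sampling_params={"temperature": 0.0, "max_new_tokens": 64},
    return_logprob=True,
    inference_engine="sglang",
)
# After the renaming above, each result dict carries the original prompt
# under "input" and the generated text under "output".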
+    @deprecated_args(
+        use_vllm={
+            "replacement": "inference_engine",
+            "mapper": lambda x: "vllm" if x is True else "huggingface",
+            "message": (
+                "use_vllm is deprecated and will be removed in a future version. "
+                "Please use `inference_engine='vllm'` instead."
+            ),
+        }
+    )
+    def prepare_inputs_for_inference(
         self,
         dataset: Dataset,
         apply_chat_template: bool = True,
+        inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
         enable_distributed_inference: bool = False,
     ) -> Union[list[str], "ray.data.Dataset"]:
         if dataset.get_type() == "text_only":
@@ -498,6 +526,7 @@ def preprocess_conversation(sample):

             return sample_out

+        # TODO: investigate performance issue
         dataset = dataset.map(
             preprocess_conversation,
             num_proc=dataset.data_args.preprocessing_num_workers,
@@ -517,19 +546,16 @@ def preprocess_conversation(sample):

         inference_inputs = [sentence for sentence in inference_inputs if len(sentence) > 0]

-        if enable_distributed_inference:
+        if inference_engine == "vllm" and enable_distributed_inference:
             inference_inputs = ray.data.from_items(
                 inference_inputs
             )  # -> dict[str, np.ndarray], {"item": array(['...', '...', '...'])}

-        return inference_inputs
+        if inference_engine == "sglang" and self.tokenizer.bos_token:
+            # consistent with the sglang bench_serving.py demo
+            inference_inputs = [sentence.replace(self.tokenizer.bos_token, "") for sentence in inference_inputs]

-    def __prepare_inputs_for_inference(
-        self,
-        dataset: Dataset,
-        **kwargs,
-    ):
-        raise NotImplementedError("prepare_inputs_for_inference is not implemented")
+        return inference_inputs

     def merge_lora_weights(self):
         if self.model_args.use_lora and not self.model_args.use_qlora:
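Putting the two renamed entry points together, an end-to-end sketch of the new API (hypothetical; `dataset` is assumed to be an existing lmflow `Dataset` and `model` an instance of this class):

prompts = model.prepare_inputs_for_inference(
    dataset,
    apply_chat_template=True,
    inference_engine="sglang",
)
results = model.inference(
    prompts,
    inference_engine="sglang",
    release_gpu=True,  # free GPU resources once generation finishes
)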