Commit a1882bb

Merge pull request #955 from OptimalScale/lmflow-sgl
[feature] sglang support
2 parents 7791f15 + cf12a47 commit a1882bb

13 files changed: 565 additions & 108 deletions (four of the changed files are shown below)

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 packaging
 numpy
-datasets==2.14.6
+datasets==3.6.0
 tokenizers>=0.13.3
 peft>=0.10.0
 torch>=2.0.1

setup.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 extra_require = {
     "multimodal": ["Pillow"],
     "vllm": ["vllm>=0.4.3"],
+    "sglang": ["sglang"],
     "ray": ["ray>=2.22.0"],
     "gradio": ["gradio"],
     "flask": ["flask", "flask_cors"],

src/lmflow/args.py

Lines changed: 91 additions & 3 deletions
@@ -947,6 +947,10 @@ class InferencerArguments:
         default=False,
         metadata={"help": "whether turn on true random sampling during inference."},
     )
+    return_logprob: Optional[bool] = field(
+        default=False,
+        metadata={"help": "whether to return log probability during inference."},
+    )
     use_accelerator: Optional[bool] = field(
         default=None,
         metadata={"help": "[Deprecated] Whether to use Huggingface Accelerator instead of Deepspeed"},

@@ -994,12 +998,58 @@ class InferencerArguments:
     )
 
     # vllm inference args
-    use_vllm: bool = field(default=False, metadata={"help": "Whether to use VLLM for inference, By default False."})
+    use_vllm: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "Whether to use VLLM for inference, By default None. Deprecated, use inference_engine instead."
+        },
+    )
     vllm_tensor_parallel_size: Optional[int] = field(
-        default=1, metadata={"help": "The tensor parallel size for VLLM inference."}
+        default=None,
+        metadata={
+            "help": (
+                "The tensor parallel size for VLLM inference. Deprecated, use inference_tensor_parallel_size instead."
+            )
+        },
     )
     vllm_gpu_memory_utilization: Optional[float] = field(
-        default=0.95, metadata={"help": "The GPU memory utilization for VLLM inference."}
+        default=None,
+        metadata={
+            "help": (
+                "The GPU memory utilization for VLLM inference. "
+                "Deprecated, use inference_gpu_memory_utilization instead."
+            )
+        },
+    )
+
+    # inference engine args
+    inference_engine: Optional[str] = field(
+        default="huggingface",
+        metadata={
+            "help": "The inference engine to use, by default huggingface.",
+            "choices": ["huggingface", "vllm", "sglang"],
+        },
+    )
+    inference_tensor_parallel_size: Optional[int] = field(
+        default=1, metadata={"help": "The tensor parallel size for inference."}
+    )
+    inference_gpu_memory_utilization: Optional[float] = field(
+        default=0.95, metadata={"help": "The GPU memory utilization for inference."}
+    )
+    enable_deterministic_inference: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to enable deterministic inference. Only supported for SGLang inference engine currently."
+        },
+    )
+    attention_backend: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The attention backend to use. Only supported for SGLang inference engine currently. "
+                "Please leave it as None to let SGLang automatically choose if you're not sure."
+            )
+        },
     )
 
     # Args for result saving

@@ -1023,6 +1073,44 @@ def __post_init__(self):
         else:
             Path(self.results_path).parent.mkdir(parents=True, exist_ok=True)
 
+        if self.use_vllm is True:
+            logger.warning(
+                "Inference engine is set to vllm. You've specified `use_vllm`. This argument is deprecated and "
+                "will be removed in a future version. Please use `inference_engine` instead."
+            )
+            self.inference_engine = "vllm"
+
+        if self.vllm_tensor_parallel_size is not None:
+            logger.warning(
+                "You've specified `vllm_tensor_parallel_size`. This argument is deprecated and "
+                "will be removed in a future version. Please use `inference_tensor_parallel_size` instead."
+            )
+            self.inference_tensor_parallel_size = self.vllm_tensor_parallel_size
+
+        if self.vllm_gpu_memory_utilization is not None:
+            logger.warning(
+                "You've specified `vllm_gpu_memory_utilization`. This argument is deprecated and "
+                "will be removed in a future version. Please use `inference_gpu_memory_utilization` instead."
+            )
+            self.inference_gpu_memory_utilization = self.vllm_gpu_memory_utilization
+
+        if self.inference_engine != "sglang":
+            if self.return_logprob:
+                logger.warning("`return_logprob` is only supported for SGLang inference engine currently. ")
+
+        if self.inference_engine == "sglang":
+            if self.enable_deterministic_inference:
+                if self.attention_backend is None:
+                    self.attention_backend = "fa3"
+                    logger.warning(
+                        "`enable_deterministic_inference` is enabled, but `attention_backend` is not specified. "
+                        "Using `fa3` as the attention backend by default."
+                    )
+                else:
+                    assert self.attention_backend in ["fa3", "flashinfer", "triton"], (
+                        "Invalid attention backend. Please choose from 'fa3', 'flashinfer', or 'triton'."
+                    )
+
 
 @dataclass
 class RaftAlignerArguments(TrainingArguments):
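
To illustrate the new arguments and the deprecation shims in `__post_init__`, a minimal sketch (it assumes the dataclass can be constructed directly with defaults for all remaining fields; in the pipeline these values normally arrive via HfArgumentParser):

    from lmflow.args import InferencerArguments

    # New style: pick the engine explicitly.
    args = InferencerArguments(
        inference_engine="sglang",
        inference_tensor_parallel_size=2,
        return_logprob=True,
        enable_deterministic_inference=True,
    )
    # __post_init__ falls back to the "fa3" attention backend and warns.
    assert args.attention_backend == "fa3"

    # Old style: the deprecated vllm flags are remapped onto the new fields.
    legacy = InferencerArguments(use_vllm=True, vllm_tensor_parallel_size=4)
    assert legacy.inference_engine == "vllm"
    assert legacy.inference_tensor_parallel_size == 4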

src/lmflow/models/hf_decoder_model.py

Lines changed: 98 additions & 72 deletions
@@ -16,7 +16,7 @@
 import hashlib
 import logging
 import os
-from typing import Optional, Union
+from typing import Literal, Optional, Union
 
 import torch
 from peft import PeftModel

@@ -37,6 +37,7 @@
 )
 from lmflow.utils.conversation_template import PRESET_TEMPLATES
 from lmflow.utils.data_utils import VLLMInferenceResultWithInput
+from lmflow.utils.deprecated import deprecated_args
 from lmflow.utils.envs import is_accelerate_env
 from lmflow.utils.versioning import is_flash_attn_available, is_ray_available, is_vllm_available
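
The newly imported `deprecated_args` decorator is not part of this diff; judging from its call sites below, a minimal hypothetical sketch of such a decorator could look like this (the actual `lmflow.utils.deprecated` implementation may differ, e.g. it may log instead of warn):

    import functools
    import warnings

    def deprecated_args(**specs):
        # Each spec mirrors the call sites below: a dict with "replacement",
        # "mapper", and "message" keys describing how to remap the old kwarg.
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                for old_name, spec in specs.items():
                    if old_name in kwargs:
                        warnings.warn(spec["message"], DeprecationWarning, stacklevel=2)
                        value = spec["mapper"](kwargs.pop(old_name))
                        # Don't clobber an explicitly passed replacement kwarg.
                        kwargs.setdefault(spec["replacement"], value)
                return func(*args, **kwargs)
            return wrapper
        return decorator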

@@ -273,43 +274,85 @@ def decode(self, input, **kwargs) -> Union[str, list[str]]:
         # Can be list of ints or a Tensor
         return self.tokenizer.decode(input, **kwargs)
 
-    def inference(self, inputs, release_gpu: bool = False, use_vllm: bool = False, **kwargs):
+    @deprecated_args(
+        use_vllm={
+            "replacement": "inference_engine",
+            "mapper": lambda x: "vllm" if x is True else "huggingface",
+            "message": (
+                "use_vllm is deprecated and will be removed in a future version. "
+                "Please use `inference_engine='vllm'` instead."
+            ),
+        }
+    )
+    def inference(
+        self,
+        inputs: Union[str, list[str], torch.Tensor],
+        sampling_params: Optional[Union[dict, "SamplingParams"]] = None,
+        return_logprob: bool = False,
+        release_gpu: bool = False,
+        inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
+        gpu_memory_utilization: Optional[float] = None,
+        tensor_parallel_size: Optional[int] = None,
+        enable_deterministic_inference: bool = False,
+        attention_backend: Optional[str] = None,
+        **kwargs,
+    ):
         """
         Perform generation process of the model.
 
         Parameters
         ------------
-        inputs :
+        inputs : Union[str, list[str], torch.Tensor]
             The sequence used as a prompt for the generation or as model inputs to the model.
-            When using vllm inference, this should be a string or a list of strings.
-            When using normal inference, this should be a tensor.
+            When the inference engine is "vllm" or "sglang", this should be a string or a list of strings.
+            When the inference engine is "huggingface", this should be a tensor.
+        sampling_params : Optional[Union[dict, "SamplingParams"]], optional
+            The sampling parameters to use, by default None.
+        return_logprob : bool, optional
+            Whether to return log probability during inference, by default False.
         release_gpu : bool, optional
             Whether to release the GPU resource after inference, by default False.
-        use_vllm : bool, optional
-            Whether to use VLLM for inference, by default False.
-        kwargs : Optional.
-            Keyword arguments.
+        inference_engine : Literal["huggingface", "vllm", "sglang"], optional
+            The inference engine to use, by default "huggingface".
+        gpu_memory_utilization : float, optional
+            The GPU memory utilization to use, by default None.
+        tensor_parallel_size : int, optional
+            The tensor parallel size to use, by default None.
+        enable_deterministic_inference : bool, optional
+            Whether to enable deterministic inference, by default False.
+        attention_backend : Optional[str], optional
+            The attention backend to use, by default None.
 
         Returns
         ------------
         outputs :
             The generated sequence output
         """
+        if isinstance(inputs, str):
+            inputs = [inputs]
+
         if not self._activated:
             self.activate_model_for_inference(
-                use_vllm=use_vllm,
-                **kwargs,
+                inference_engine=inference_engine,
+                gpu_memory_utilization=gpu_memory_utilization,
+                tensor_parallel_size=tensor_parallel_size,
+                enable_deterministic_inference=enable_deterministic_inference,
+                attention_backend=attention_backend,
             )
 
-        if use_vllm:
-            if not is_vllm_available():
-                raise ImportError("vllm is not installed. Please install vllm to use VLLM inference.")
-            res = self.__vllm_inference(inputs, **kwargs)
+        if inference_engine == "vllm":
+            res = self.__vllm_inference(inputs=inputs, sampling_params=sampling_params)
+        elif inference_engine == "sglang":
+            res = self.__sglang_inference(
+                inputs=inputs,
+                sampling_params=sampling_params,
+                return_logprob=return_logprob,
+            )
         else:
-            res = self.__inference(inputs, **kwargs)
+            res = self.__inference(inputs=inputs, **kwargs)
 
         if release_gpu:
-            self.deactivate_model_for_inference(use_vllm=use_vllm)
+            self.deactivate_model_for_inference(inference_engine=inference_engine)
 
         return res
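
A usage sketch for the rewritten method (the model object and prompts are assumed; per the signature above, `sampling_params` can be a plain dict on the SGLang path):

    results = model.inference(
        inputs="What is LMFlow?",  # a bare string is wrapped into a list
        sampling_params={"temperature": 0.0, "max_new_tokens": 128},
        return_logprob=True,
        inference_engine="sglang",
    )

    # The deprecated spelling still works: @deprecated_args warns and remaps
    # use_vllm=True onto inference_engine="vllm".
    results = model.inference(inputs=["What is LMFlow?"], use_vllm=True)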

@@ -353,15 +396,14 @@ def __inference(self, inputs, *args, **kwargs):
 
     def __vllm_inference(
         self,
-        inputs: Union[str, list[str]],
+        inputs: list[str],
         sampling_params: Optional["SamplingParams"] = None,
-        **kwargs,
     ) -> list[VLLMInferenceResultWithInput]:
         """Perform VLLM inference process of the model.
 
         Parameters
         ----------
-        inputs : Union[str, list[str]]
+        inputs : list[str]
             Prompt(s), string or a list of strings.
         sampling_params : Optional[SamplingParams], optional
             vllm SamplingParams object, by default None.
@@ -383,6 +425,7 @@ def __vllm_inference(
             sampling_params=sampling_params,
             use_tqdm=True,
         )
+        # TODO: unified lmflow sample format
         final_output = []
         for output in vllm_outputs:
             if sampling_params.detokenize:
@@ -394,54 +437,39 @@ def __vllm_inference(
 
         return final_output
 
-    def prepare_inputs_for_inference(
+    def __sglang_inference(
         self,
-        dataset: Dataset,
-        apply_chat_template: bool = True,
-        enable_distributed_inference: bool = False,
-        use_vllm: bool = False,
-        **kwargs,
-    ) -> Union[list[str], "ray.data.Dataset", dict[str, torch.Tensor]]:
-        """
-        Prepare inputs for inference.
-
-        Parameters
-        ------------
-        dataset : lmflow.datasets.Dataset.
-            The dataset used for inference.
-
-        args : Optional.
-            Positional arguments.
-
-        kwargs : Optional.
-            Keyword arguments.
-
-        Returns
-        ------------
-        outputs :
-            The prepared inputs for inference.
-        """
-        if use_vllm:
-            if not is_ray_available() and enable_distributed_inference:
-                raise ImportError("ray is not installed. Please install ray to use distributed vllm inference.")
-            inference_inputs = self.__prepare_inputs_for_vllm_inference(
-                dataset=dataset,
-                apply_chat_template=apply_chat_template,
-                enable_distributed_inference=enable_distributed_inference,
-            )
-        else:
-            inference_inputs = self.__prepare_inputs_for_inference(
-                dataset,
-                apply_chat_template=apply_chat_template,
-                enable_distributed_inference=enable_distributed_inference,
-            )
-
-        return inference_inputs
-
-    def __prepare_inputs_for_vllm_inference(
+        inputs: list[str],
+        sampling_params: Optional[dict] = None,
+        return_logprob: bool = False,
+    ):
+        """Perform SGLang inference process of the model."""
+        sglang_outputs = self.backend_model_for_inference.generate(
+            prompt=inputs,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+        )
+        # TODO: unified lmflow sample format
+        for idx, output in enumerate(sglang_outputs):
+            output["input"] = inputs[idx]
+            output["output"] = output.pop("text")
+        return sglang_outputs
+
+    @deprecated_args(
+        use_vllm={
+            "replacement": "inference_engine",
+            "mapper": lambda x: "vllm" if x is True else "huggingface",
+            "message": (
+                "use_vllm is deprecated and will be removed in a future version. "
+                "Please use `inference_engine='vllm'` instead."
+            ),
+        }
+    )
+    def prepare_inputs_for_inference(
         self,
         dataset: Dataset,
         apply_chat_template: bool = True,
+        inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
         enable_distributed_inference: bool = False,
     ) -> Union[list[str], "ray.data.Dataset"]:
         if dataset.get_type() == "text_only":
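
For reference, after the renaming loop in `__sglang_inference` each record should look roughly like this (illustrative; SGLang's `generate` also returns engine metadata, which passes through untouched):

    record = {
        "input": "What is LMFlow?",  # the prompt, attached by the loop
        "output": "LMFlow is ...",   # renamed from SGLang's "text" key
        # ...any other keys from generate() remain as-is
    }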
@@ -498,6 +526,7 @@ def preprocess_conversation(sample):
 
             return sample_out
 
+        # TODO: investigate performance issue
        dataset = dataset.map(
             preprocess_conversation,
             num_proc=dataset.data_args.preprocessing_num_workers,
@@ -517,19 +546,16 @@ def preprocess_conversation(sample):
 
         inference_inputs = [sentence for sentence in inference_inputs if len(sentence) > 0]
 
-        if enable_distributed_inference:
+        if inference_engine == "vllm" and enable_distributed_inference:
             inference_inputs = ray.data.from_items(
                 inference_inputs
             )  # -> dict[str, np.ndarray], {"item": array(['...', '...', '...'])}
 
-        return inference_inputs
+        if inference_engine == "sglang" and self.tokenizer.bos_token:
+            # in consistent with sglang bench_serving.py demo
+            inference_inputs = [sentence.replace(self.tokenizer.bos_token, "") for sentence in inference_inputs]
 
-    def __prepare_inputs_for_inference(
-        self,
-        dataset: Dataset,
-        **kwargs,
-    ):
-        raise NotImplementedError("prepare_inputs_for_inference is not implemented")
+        return inference_inputs
 
     def merge_lora_weights(self):
         if self.model_args.use_lora and not self.model_args.use_qlora:
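
Putting the two decorated methods together, a hypothetical end-to-end flow (dataset and model construction omitted; values are illustrative):

    prompts = model.prepare_inputs_for_inference(
        dataset=dataset,
        apply_chat_template=True,
        inference_engine="sglang",  # also strips the BOS token, per the hunk above
    )
    results = model.inference(
        inputs=prompts,
        sampling_params={"temperature": 0.7, "max_new_tokens": 256},
        inference_engine="sglang",
        release_gpu=True,  # deactivate the engine afterwards
    )
    for r in results:
        print(r["input"], "->", r["output"])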
