2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,6 +1,6 @@
packaging
numpy
datasets==2.14.6
datasets==3.6.0
tokenizers>=0.13.3
peft>=0.10.0
torch>=2.0.1
1 change: 1 addition & 0 deletions setup.py
@@ -18,6 +18,7 @@
extra_require = {
"multimodal": ["Pillow"],
"vllm": ["vllm>=0.4.3"],
"sglang": ["sglang"],
"ray": ["ray>=2.22.0"],
"gradio": ["gradio"],
"flask": ["flask", "flask_cors"],
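Since this PR adds `sglang` as a new optional extra in `setup.py` alongside `vllm` and `ray`, a quick way to confirm which optional inference backends are present in an environment is an import-availability check. This is an illustrative sketch only, not part of the PR:

```python
# Check which optional inference backends are importable (illustrative only).
from importlib.util import find_spec

for backend in ("vllm", "sglang", "ray"):
    status = "installed" if find_spec(backend) is not None else "missing"
    print(f"{backend}: {status}")
```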
94 changes: 91 additions & 3 deletions src/lmflow/args.py
@@ -947,6 +947,10 @@ class InferencerArguments:
default=False,
metadata={"help": "whether turn on true random sampling during inference."},
)
return_logprob: Optional[bool] = field(
default=False,
metadata={"help": "whether to return log probability during inference."},
)
use_accelerator: Optional[bool] = field(
default=None,
metadata={"help": "[Deprecated] Whether to use Huggingface Accelerator instead of Deepspeed"},
@@ -994,12 +998,58 @@ class InferencerArguments:
)

# vllm inference args
use_vllm: bool = field(default=False, metadata={"help": "Whether to use VLLM for inference, By default False."})
use_vllm: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to use VLLM for inference, By default None. Deprecated, use inference_engine instead."
},
)
vllm_tensor_parallel_size: Optional[int] = field(
default=1, metadata={"help": "The tensor parallel size for VLLM inference."}
default=None,
metadata={
"help": (
"The tensor parallel size for VLLM inference. Deprecated, use inference_tensor_parallel_size instead."
)
},
)
vllm_gpu_memory_utilization: Optional[float] = field(
default=0.95, metadata={"help": "The GPU memory utilization for VLLM inference."}
default=None,
metadata={
"help": (
"The GPU memory utilization for VLLM inference. "
"Deprecated, use inference_gpu_memory_utilization instead."
)
},
)

# inference engine args
inference_engine: Optional[str] = field(
default="huggingface",
metadata={
"help": "The inference engine to use, by default huggingface.",
"choices": ["huggingface", "vllm", "sglang"],
},
)
inference_tensor_parallel_size: Optional[int] = field(
default=1, metadata={"help": "The tensor parallel size for inference."}
)
inference_gpu_memory_utilization: Optional[float] = field(
default=0.95, metadata={"help": "The GPU memory utilization for inference."}
)
enable_deterministic_inference: bool = field(
default=False,
metadata={
"help": "Whether to enable deterministic inference. Only supported for SGLang inference engine currently."
},
)
attention_backend: Optional[str] = field(
default=None,
metadata={
"help": (
"The attention backend to use. Only supported for SGLang inference engine currently. "
"Please leave it as None to let SGLang automatically choose if you're not sure."
)
},
)

# Args for result saving
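The engine-agnostic fields above (`inference_engine`, `inference_tensor_parallel_size`, `inference_gpu_memory_utilization`, `enable_deterministic_inference`, `attention_backend`) supersede the vLLM-specific flags. A minimal sketch of selecting the SGLang engine through the dataclass directly; the result-saving arguments not shown in this diff are assumed to keep workable defaults:

```python
# Illustrative only; field names come from the diff above, other defaults are assumed.
from lmflow.args import InferencerArguments

inferencer_args = InferencerArguments(
    inference_engine="sglang",             # replaces use_vllm
    inference_tensor_parallel_size=2,      # replaces vllm_tensor_parallel_size
    inference_gpu_memory_utilization=0.9,  # replaces vllm_gpu_memory_utilization
    return_logprob=True,                   # new flag, honored by the SGLang engine
)
```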
@@ -1023,6 +1073,44 @@ def __post_init__(self):
else:
Path(self.results_path).parent.mkdir(parents=True, exist_ok=True)

if self.use_vllm is True:
logger.warning(
"Inference engine is set to vllm. You've specified `use_vllm`. This argument is deprecated and "
"will be removed in a future version. Please use `inference_engine` instead."
)
self.inference_engine = "vllm"

if self.vllm_tensor_parallel_size is not None:
logger.warning(
"You've specified `vllm_tensor_parallel_size`. This argument is deprecated and "
"will be removed in a future version. Please use `inference_tensor_parallel_size` instead."
)
self.inference_tensor_parallel_size = self.vllm_tensor_parallel_size

if self.vllm_gpu_memory_utilization is not None:
logger.warning(
"You've specified `vllm_gpu_memory_utilization`. This argument is deprecated and "
"will be removed in a future version. Please use `inference_gpu_memory_utilization` instead."
)
self.inference_gpu_memory_utilization = self.vllm_gpu_memory_utilization

if self.inference_engine != "sglang":
if self.return_logprob:
logger.warning("`return_logprob` is currently only supported by the SGLang inference engine.")

if self.inference_engine == "sglang":
if self.enable_deterministic_inference:
if self.attention_backend is None:
self.attention_backend = "fa3"
logger.warning(
"`enable_deterministic_inference` is enabled, but `attention_backend` is not specified. "
"Using `fa3` as the attention backend by default."
)
else:
assert self.attention_backend in ["fa3", "flashinfer", "triton"], (
"Invalid attention backend. Please choose from 'fa3', 'flashinfer', or 'triton'."
)


@dataclass
class RaftAlignerArguments(TrainingArguments):
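The `__post_init__` additions keep the old vLLM flags working by remapping them onto the new engine-agnostic arguments, and pin the attention backend to `fa3` when deterministic SGLang inference is requested without one. A hedged sketch of the resulting behavior (defaults not shown in this diff are assumed to be permissive):

```python
from lmflow.args import InferencerArguments

# Deprecated flags are remapped with a warning.
legacy = InferencerArguments(use_vllm=True, vllm_tensor_parallel_size=4)
assert legacy.inference_engine == "vllm"
assert legacy.inference_tensor_parallel_size == 4

# Deterministic SGLang inference without an explicit backend falls back to "fa3".
deterministic = InferencerArguments(inference_engine="sglang", enable_deterministic_inference=True)
assert deterministic.attention_backend == "fa3"
```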
170 changes: 98 additions & 72 deletions src/lmflow/models/hf_decoder_model.py
@@ -16,7 +16,7 @@
import hashlib
import logging
import os
from typing import Optional, Union
from typing import Literal, Optional, Union

import torch
from peft import PeftModel
@@ -37,6 +37,7 @@
)
from lmflow.utils.conversation_template import PRESET_TEMPLATES
from lmflow.utils.data_utils import VLLMInferenceResultWithInput
from lmflow.utils.deprecated import deprecated_args
from lmflow.utils.envs import is_accelerate_env
from lmflow.utils.versioning import is_flash_attn_available, is_ray_available, is_vllm_available

@@ -273,43 +274,85 @@ def decode(self, input, **kwargs) -> Union[str, list[str]]:
# Can be list of ints or a Tensor
return self.tokenizer.decode(input, **kwargs)

def inference(self, inputs, release_gpu: bool = False, use_vllm: bool = False, **kwargs):
@deprecated_args(
use_vllm={
"replacement": "inference_engine",
"mapper": lambda x: "vllm" if x is True else "huggingface",
"message": (
"use_vllm is deprecated and will be removed in a future version. "
"Please use `inference_engine='vllm'` instead."
),
}
)
def inference(
self,
inputs: Union[str, list[str], torch.Tensor],
sampling_params: Optional[Union[dict, "SamplingParams"]] = None,
return_logprob: bool = False,
release_gpu: bool = False,
inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
gpu_memory_utilization: Optional[float] = None,
tensor_parallel_size: Optional[int] = None,
enable_deterministic_inference: bool = False,
attention_backend: Optional[str] = None,
**kwargs,
):
"""
Perform generation process of the model.

Parameters
------------
inputs :
inputs : Union[str, list[str], torch.Tensor]
The sequence used as a prompt for the generation or as model inputs to the model.
When using vllm inference, this should be a string or a list of strings.
When using normal inference, this should be a tensor.
When the inference engine is "vllm" or "sglang", this should be a string or a list of strings.
When the inference engine is "huggingface", this should be a tensor.
sampling_params : Optional[Union[dict, "SamplingParams"]], optional
The sampling parameters to use, by default None.
return_logprob : bool, optional
Whether to return log probability during inference, by default False.
release_gpu : bool, optional
Whether to release the GPU resource after inference, by default False.
use_vllm : bool, optional
Whether to use VLLM for inference, by default False.
kwargs : Optional.
Keyword arguments.
inference_engine : Literal["huggingface", "vllm", "sglang"], optional
The inference engine to use, by default "huggingface".
gpu_memory_utilization : float, optional
The GPU memory utilization to use, by default None.
tensor_parallel_size : int, optional
The tensor parallel size to use, by default None.
enable_deterministic_inference : bool, optional
Whether to enable deterministic inference, by default False.
attention_backend : Optional[str], optional
The attention backend to use, by default None.

Returns
------------
outputs :
The generated sequence output
"""
if isinstance(inputs, str):
inputs = [inputs]

if not self._activated:
self.activate_model_for_inference(
use_vllm=use_vllm,
**kwargs,
inference_engine=inference_engine,
gpu_memory_utilization=gpu_memory_utilization,
tensor_parallel_size=tensor_parallel_size,
enable_deterministic_inference=enable_deterministic_inference,
attention_backend=attention_backend,
)

if use_vllm:
if not is_vllm_available():
raise ImportError("vllm is not installed. Please install vllm to use VLLM inference.")
res = self.__vllm_inference(inputs, **kwargs)
if inference_engine == "vllm":
res = self.__vllm_inference(inputs=inputs, sampling_params=sampling_params)
elif inference_engine == "sglang":
res = self.__sglang_inference(
inputs=inputs,
sampling_params=sampling_params,
return_logprob=return_logprob,
)
else:
res = self.__inference(inputs, **kwargs)
res = self.__inference(inputs=inputs, **kwargs)

if release_gpu:
self.deactivate_model_for_inference(use_vllm=use_vllm)
self.deactivate_model_for_inference(inference_engine=inference_engine)

return res
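With the reworked signature, engine selection, sampling parameters, and the SGLang-specific options all flow through `inference()` directly instead of loose `**kwargs`. A minimal usage sketch, assuming `model` is an already constructed `HFDecoderModel` and that the SGLang engine accepts a plain dict of sampling parameters:

```python
# Illustrative call; model construction and prompt preparation are omitted.
prompts = ["What is the capital of France?", "Explain tensor parallelism in one sentence."]

outputs = model.inference(
    inputs=prompts,
    inference_engine="sglang",
    sampling_params={"temperature": 0.7, "max_new_tokens": 128},
    return_logprob=True,
    release_gpu=True,  # free the engine once the batch finishes
)

for item in outputs:
    print(item["input"], "->", item["output"])
```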

@@ -353,15 +396,14 @@ def __inference(self, inputs, *args, **kwargs):

def __vllm_inference(
self,
inputs: Union[str, list[str]],
inputs: list[str],
sampling_params: Optional["SamplingParams"] = None,
**kwargs,
) -> list[VLLMInferenceResultWithInput]:
"""Perform VLLM inference process of the model.

Parameters
----------
inputs : Union[str, list[str]]
inputs : list[str]
Prompt(s), string or a list of strings.
sampling_params : Optional[SamplingParams], optional
vllm SamplingParams object, by default None.
@@ -383,6 +425,7 @@ def __vllm_inference(
sampling_params=sampling_params,
use_tqdm=True,
)
# TODO: unified lmflow sample format
final_output = []
for output in vllm_outputs:
if sampling_params.detokenize:
@@ -394,54 +437,39 @@

return final_output

def prepare_inputs_for_inference(
def __sglang_inference(
self,
dataset: Dataset,
apply_chat_template: bool = True,
enable_distributed_inference: bool = False,
use_vllm: bool = False,
**kwargs,
) -> Union[list[str], "ray.data.Dataset", dict[str, torch.Tensor]]:
"""
Prepare inputs for inference.

Parameters
------------
dataset : lmflow.datasets.Dataset.
The dataset used for inference.

args : Optional.
Positional arguments.

kwargs : Optional.
Keyword arguments.

Returns
------------
outputs :
The prepared inputs for inference.
"""
if use_vllm:
if not is_ray_available() and enable_distributed_inference:
raise ImportError("ray is not installed. Please install ray to use distributed vllm inference.")
inference_inputs = self.__prepare_inputs_for_vllm_inference(
dataset=dataset,
apply_chat_template=apply_chat_template,
enable_distributed_inference=enable_distributed_inference,
)
else:
inference_inputs = self.__prepare_inputs_for_inference(
dataset,
apply_chat_template=apply_chat_template,
enable_distributed_inference=enable_distributed_inference,
)

return inference_inputs

def __prepare_inputs_for_vllm_inference(
inputs: list[str],
sampling_params: Optional[dict] = None,
return_logprob: bool = False,
):
"""Perform SGLang inference process of the model."""
sglang_outputs = self.backend_model_for_inference.generate(
prompt=inputs,
sampling_params=sampling_params,
return_logprob=return_logprob,
)
# TODO: unified lmflow sample format
for idx, output in enumerate(sglang_outputs):
output["input"] = inputs[idx]
output["output"] = output.pop("text")
return sglang_outputs
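`__sglang_inference` assumes `backend_model_for_inference` is an SGLang engine whose `generate` call returns one dict per prompt with at least a `text` field (plus metadata such as log probabilities when `return_logprob=True`); the loop then normalizes each dict toward the input/output shape the vLLM path produces. Roughly, with a hypothetical result:

```python
# Assumed shape of one raw SGLang result and its normalized form (illustrative only).
raw = {"text": " Paris.", "meta_info": {"completion_tokens": 3}}  # engine output; fields assumed
prompt = "What is the capital of France?"

raw["input"] = prompt
raw["output"] = raw.pop("text")
# raw == {"meta_info": {"completion_tokens": 3}, "input": "What is the capital of France?", "output": " Paris."}
```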

@deprecated_args(
use_vllm={
"replacement": "inference_engine",
"mapper": lambda x: "vllm" if x is True else "huggingface",
"message": (
"use_vllm is deprecated and will be removed in a future version. "
"Please use `inference_engine='vllm'` instead."
),
}
)
def prepare_inputs_for_inference(
self,
dataset: Dataset,
apply_chat_template: bool = True,
inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
enable_distributed_inference: bool = False,
) -> Union[list[str], "ray.data.Dataset"]:
if dataset.get_type() == "text_only":
@@ -498,6 +526,7 @@ def preprocess_conversation(sample):

return sample_out

# TODO: investigate performance issue
dataset = dataset.map(
preprocess_conversation,
num_proc=dataset.data_args.preprocessing_num_workers,
@@ -517,19 +546,16 @@ def preprocess_conversation(sample):

inference_inputs = [sentence for sentence in inference_inputs if len(sentence) > 0]

if enable_distributed_inference:
if inference_engine == "vllm" and enable_distributed_inference:
inference_inputs = ray.data.from_items(
inference_inputs
) # -> dict[str, np.ndarray], {"item": array(['...', '...', '...'])}

return inference_inputs
if inference_engine == "sglang" and self.tokenizer.bos_token:
# consistent with sglang's bench_serving.py demo
inference_inputs = [sentence.replace(self.tokenizer.bos_token, "") for sentence in inference_inputs]

def __prepare_inputs_for_inference(
self,
dataset: Dataset,
**kwargs,
):
raise NotImplementedError("prepare_inputs_for_inference is not implemented")
return inference_inputs
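Stripping the BOS token from already-templated prompts mirrors SGLang's `bench_serving.py` demo; presumably this avoids a doubled BOS once SGLang tokenizes the prompt string itself. A tiny illustration with hypothetical tokenizer values:

```python
# Hypothetical chat-templated prompts that already carry the BOS token.
bos_token = "<s>"
inference_inputs = ["<s>[INST] Hello [/INST]", "<s>[INST] Hi there [/INST]"]

inference_inputs = [s.replace(bos_token, "") for s in inference_inputs]
# -> ["[INST] Hello [/INST]", "[INST] Hi there [/INST]"]
```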

def merge_lora_weights(self):
if self.model_args.use_lora and not self.model_args.use_qlora:
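The newly imported `lmflow.utils.deprecated.deprecated_args` decorator is used twice above, but its implementation is not part of this diff. A minimal sketch of how such a keyword-remapping decorator could work, matching the `replacement`/`mapper`/`message` keys used here (purely illustrative; the real helper may differ):

```python
import functools
import warnings


def deprecated_args(**deprecations):
    """Remap deprecated keyword arguments onto their replacements (illustrative sketch)."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for old_name, spec in deprecations.items():
                if old_name not in kwargs:
                    continue
                value = kwargs.pop(old_name)
                warnings.warn(
                    spec.get("message", f"`{old_name}` is deprecated."),
                    FutureWarning,
                    stacklevel=2,
                )
                replacement = spec.get("replacement")
                if replacement is not None and replacement not in kwargs:
                    kwargs[replacement] = spec.get("mapper", lambda v: v)(value)
            return func(*args, **kwargs)

        return wrapper

    return decorator
```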