From 5da1d35c227b94ee44b1b71b60125f74ca15fb86 Mon Sep 17 00:00:00 2001 From: YizhenJia Date: Sun, 23 Nov 2025 15:13:53 +0800 Subject: [PATCH 1/6] [feature] add sglang support --- setup.py | 1 + src/lmflow/args.py | 51 +++++- src/lmflow/models/hf_decoder_model.py | 149 +++++++++--------- src/lmflow/models/hf_model_mixin.py | 62 +++++--- src/lmflow/pipeline/sglang_inferencer.py | 110 +++++++++++++ src/lmflow/pipeline/vllm_inferencer.py | 20 +-- src/lmflow/utils/deprecated.py | 77 +++++++++ src/lmflow/utils/versioning.py | 4 + tests/datasets/conftest.py | 13 ++ .../test_memory_safe_vllm_inferencer.py | 4 +- tests/pipeline/test_sglang_infernecer.py | 35 ++++ 11 files changed, 419 insertions(+), 107 deletions(-) create mode 100644 src/lmflow/pipeline/sglang_inferencer.py create mode 100644 src/lmflow/utils/deprecated.py create mode 100644 tests/datasets/conftest.py create mode 100644 tests/pipeline/test_sglang_infernecer.py diff --git a/setup.py b/setup.py index 1a8393607..e810b8a2b 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,7 @@ extra_require = { "multimodal": ["Pillow"], "vllm": ["vllm>=0.4.3"], + "sglang": ["sglang"], "ray": ["ray>=2.22.0"], "gradio": ["gradio"], "flask": ["flask", "flask_cors"], diff --git a/src/lmflow/args.py b/src/lmflow/args.py index e8cb15781..1a1228ff0 100644 --- a/src/lmflow/args.py +++ b/src/lmflow/args.py @@ -994,14 +994,36 @@ class InferencerArguments: ) # vllm inference args - use_vllm: bool = field(default=False, metadata={"help": "Whether to use VLLM for inference, By default False."}) + use_vllm: Optional[bool] = field( + default=None, + metadata={"help": "Whether to use VLLM for inference, By default None. Deprecated, use inference_engine instead."} + ) vllm_tensor_parallel_size: Optional[int] = field( - default=1, metadata={"help": "The tensor parallel size for VLLM inference."} + default=None, + metadata={"help": "The tensor parallel size for VLLM inference. 
Deprecated, use inference_tensor_parallel_size instead."} ) vllm_gpu_memory_utilization: Optional[float] = field( - default=0.95, metadata={"help": "The GPU memory utilization for VLLM inference."} + default=None, + metadata={"help": "The GPU memory utilization for VLLM inference. Deprecated, use inference_gpu_memory_utilization instead."} ) - + + # inference engine args + inference_engine: Optional[str] = field( + default="huggingface", + metadata={ + "help": "The inference engine to use, by default huggingface.", + "choices": ["huggingface", "vllm", "sglang"] + } + ) + inference_tensor_parallel_size: Optional[int] = field( + default=1, + metadata={"help": "The tensor parallel size for inference."} + ) + inference_gpu_memory_utilization: Optional[float] = field( + default=0.95, + metadata={"help": "The GPU memory utilization for inference."} + ) + # Args for result saving save_results: Optional[bool] = field(default=False, metadata={"help": "Whether to save inference results."}) results_path: Optional[str] = field(default=None, metadata={"help": "The path of inference results."}) @@ -1022,6 +1044,27 @@ def __post_init__(self): raise ValueError("The results_path must be a json file.") else: Path(self.results_path).parent.mkdir(parents=True, exist_ok=True) + + if self.use_vllm is True: + logger.warning( + "Inference engine is set to vllm. You've specified `use_vllm`. This argument is deprecated and " + "will be removed in a future version. Please use `inference_engine` instead." + ) + self.inference_engine = "vllm" + + if self.vllm_tensor_parallel_size is not None: + logger.warning( + "You've specified `vllm_tensor_parallel_size`. This argument is deprecated and " + "will be removed in a future version. Please use `inference_tensor_parallel_size` instead." + ) + self.inference_tensor_parallel_size = self.vllm_tensor_parallel_size + + if self.vllm_gpu_memory_utilization is not None: + logger.warning( + "You've specified `vllm_gpu_memory_utilization`. 
This argument is deprecated and " + "will be removed in a future version. Please use `inference_gpu_memory_utilization` instead." + ) + self.inference_gpu_memory_utilization = self.vllm_gpu_memory_utilization @dataclass diff --git a/src/lmflow/models/hf_decoder_model.py b/src/lmflow/models/hf_decoder_model.py index 25b4aae9f..a4ec719fd 100644 --- a/src/lmflow/models/hf_decoder_model.py +++ b/src/lmflow/models/hf_decoder_model.py @@ -16,7 +16,7 @@ import hashlib import logging import os -from typing import Optional, Union +from typing import Literal, Optional, Union import torch from peft import PeftModel @@ -37,6 +37,7 @@ ) from lmflow.utils.conversation_template import PRESET_TEMPLATES from lmflow.utils.data_utils import VLLMInferenceResultWithInput +from lmflow.utils.deprecated import deprecated_args from lmflow.utils.envs import is_accelerate_env from lmflow.utils.versioning import is_flash_attn_available, is_ray_available, is_vllm_available @@ -273,43 +274,67 @@ def decode(self, input, **kwargs) -> Union[str, list[str]]: # Can be list of ints or a Tensor return self.tokenizer.decode(input, **kwargs) - def inference(self, inputs, release_gpu: bool = False, use_vllm: bool = False, **kwargs): + @deprecated_args( + use_vllm={ + 'replacement': 'inference_engine', + 'mapper': lambda x: 'vllm' if x is True else 'huggingface', + 'message': "use_vllm is deprecated and will be removed in a future version. Please use `inference_engine='vllm'` instead." + } + ) + def inference( + self, + inputs: Union[str, list[str], torch.Tensor], + sampling_params: Optional[Union[dict, "SamplingParams"]] = None, + release_gpu: bool = False, + inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface", + gpu_memory_utilization: Optional[float] = None, + tensor_parallel_size: Optional[int] = None, + **kwargs + ): """ Perform generation process of the model. 
Parameters ------------ - inputs : + inputs : Union[str, list[str], torch.Tensor] The sequence used as a prompt for the generation or as model inputs to the model. - When using vllm inference, this should be a string or a list of strings. - When using normal inference, this should be a tensor. + When the inference engine is "vllm" or "sglang", this should be a string or a list of strings. + When the inference engine is "huggingface", this should be a tensor. + sampling_params : Optional[Union[dict, "SamplingParams"]], optional + The sampling parameters to use, by default None. release_gpu : bool, optional Whether to release the GPU resource after inference, by default False. - use_vllm : bool, optional - Whether to use VLLM for inference, by default False. - kwargs : Optional. - Keyword arguments. + inference_engine : Literal["huggingface", "vllm", "sglang"], optional + The inference engine to use, by default "huggingface". + gpu_memory_utilization : float, optional + The GPU memory utilization to use, by default None. + tensor_parallel_size : int, optional + The tensor parallel size to use, by default None. Returns ------------ outputs : The generated sequence output """ + if isinstance(inputs, str): + inputs = [inputs] + if not self._activated: self.activate_model_for_inference( - use_vllm=use_vllm, - **kwargs, + inference_engine=inference_engine, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tensor_parallel_size, ) - - if use_vllm: - if not is_vllm_available(): - raise ImportError("vllm is not installed. 
Please install vllm to use VLLM inference.") - res = self.__vllm_inference(inputs, **kwargs) + + if inference_engine == "vllm": + res = self.__vllm_inference(inputs=inputs, sampling_params=sampling_params) + elif inference_engine == "sglang": + res = self.__sglang_inference(inputs=inputs, sampling_params=sampling_params) else: - res = self.__inference(inputs, **kwargs) + res = self.__inference(inputs=inputs, **kwargs) if release_gpu: - self.deactivate_model_for_inference(use_vllm=use_vllm) + self.deactivate_model_for_inference(inference_engine=inference_engine) return res @@ -353,15 +378,14 @@ def __inference(self, inputs, *args, **kwargs): def __vllm_inference( self, - inputs: Union[str, list[str]], + inputs: list[str], sampling_params: Optional["SamplingParams"] = None, - **kwargs, ) -> list[VLLMInferenceResultWithInput]: """Perform VLLM inference process of the model. Parameters ---------- - inputs : Union[str, list[str]] + inputs : list[str] Prompt(s), string or a list of strings. sampling_params : Optional[SamplingParams], optional vllm SamplingParams object, by default None. @@ -383,6 +407,7 @@ def __vllm_inference( sampling_params=sampling_params, use_tqdm=True, ) + # TODO: unified lmflow sample format final_output = [] for output in vllm_outputs: if sampling_params.detokenize: @@ -393,55 +418,36 @@ def __vllm_inference( final_output.append({"input": output.prompt, "output": output_list}) return final_output - - def prepare_inputs_for_inference( + + def __sglang_inference( self, - dataset: Dataset, - apply_chat_template: bool = True, - enable_distributed_inference: bool = False, - use_vllm: bool = False, - **kwargs, - ) -> Union[list[str], "ray.data.Dataset", dict[str, torch.Tensor]]: - """ - Prepare inputs for inference. - - Parameters - ------------ - dataset : lmflow.datasets.Dataset. - The dataset used for inference. - - args : Optional. - Positional arguments. - - kwargs : Optional. - Keyword arguments. 
- - Returns - ------------ - outputs : - The prepared inputs for inference. + inputs: list[str], + sampling_params: Optional[dict] = None, + ): + """Perform SGLang inference process of the model. """ - if use_vllm: - if not is_ray_available() and enable_distributed_inference: - raise ImportError("ray is not installed. Please install ray to use distributed vllm inference.") - inference_inputs = self.__prepare_inputs_for_vllm_inference( - dataset=dataset, - apply_chat_template=apply_chat_template, - enable_distributed_inference=enable_distributed_inference, - ) - else: - inference_inputs = self.__prepare_inputs_for_inference( - dataset, - apply_chat_template=apply_chat_template, - enable_distributed_inference=enable_distributed_inference, - ) - - return inference_inputs - - def __prepare_inputs_for_vllm_inference( + sglang_outputs = self.backend_model_for_inference.generate( + prompt=inputs, + sampling_params=sampling_params, + ) + # TODO: unified lmflow sample format + for idx, output in enumerate(sglang_outputs): + output['input'] = inputs[idx] + output['output'] = output.pop('text') + return sglang_outputs + + @deprecated_args( + use_vllm={ + 'replacement': 'inference_engine', + 'mapper': lambda x: 'vllm' if x is True else 'huggingface', + 'message': "use_vllm is deprecated and will be removed in a future version. Please use `inference_engine='vllm'` instead." 
+ } + ) + def prepare_inputs_for_inference( self, dataset: Dataset, apply_chat_template: bool = True, + inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface", enable_distributed_inference: bool = False, ) -> Union[list[str], "ray.data.Dataset"]: if dataset.get_type() == "text_only": @@ -517,20 +523,17 @@ def preprocess_conversation(sample): inference_inputs = [sentence for sentence in inference_inputs if len(sentence) > 0] - if enable_distributed_inference: + if inference_engine == "vllm" and enable_distributed_inference: inference_inputs = ray.data.from_items( inference_inputs ) # -> dict[str, np.ndarray], {"item": array(['...', '...', '...'])} + + if inference_engine == "sglang" and self.tokenizer.bos_token: + # in consistent with sglang bench_serving.py demo + inference_inputs = [sentence.replace(self.tokenizer.bos_token, "") for sentence in inference_inputs] return inference_inputs - def __prepare_inputs_for_inference( - self, - dataset: Dataset, - **kwargs, - ): - raise NotImplementedError("prepare_inputs_for_inference is not implemented") - def merge_lora_weights(self): if self.model_args.use_lora and not self.model_args.use_qlora: self.get_backend_model().merge_and_unload() diff --git a/src/lmflow/models/hf_model_mixin.py b/src/lmflow/models/hf_model_mixin.py index 42ff9b398..6ea1cf1c5 100644 --- a/src/lmflow/models/hf_model_mixin.py +++ b/src/lmflow/models/hf_model_mixin.py @@ -4,7 +4,7 @@ import gc import logging from contextlib import nullcontext -from typing import Optional, Union +from typing import Literal, Optional, Union import torch from peft import LoraConfig, PeftModel, TaskType, get_peft_model @@ -25,11 +25,7 @@ from lmflow.models.base_model import BaseModel from lmflow.utils.constants import LMFLOW_LORA_TARGET_MODULES_MAPPING from lmflow.utils.envs import is_accelerate_env -from lmflow.utils.versioning import is_deepspeed_available, is_vllm_available - -if is_vllm_available(): - from vllm import LLM - from 
vllm.distributed.parallel_state import destroy_model_parallel +from lmflow.utils.versioning import is_deepspeed_available, is_vllm_available, is_sglang_available logger = logging.getLogger(__name__) @@ -451,20 +447,41 @@ def __prepare_model_for_inference( def __prepare_model_for_vllm_inference( self, model_args: ModelArguments, - vllm_gpu_memory_utilization: float, - vllm_tensor_parallel_size: int, + gpu_memory_utilization: float, + tensor_parallel_size: int, ): if not is_vllm_available(): raise ImportError('VLLM is not available. Please install via `pip install -e ".[vllm]"`.') + + from vllm import LLM self.backend_model_for_inference = LLM( model=model_args.model_name_or_path, tokenizer=model_args.model_name_or_path, dtype=model_args.torch_dtype if model_args.torch_dtype else "auto", load_format="auto", - gpu_memory_utilization=vllm_gpu_memory_utilization, - tensor_parallel_size=vllm_tensor_parallel_size, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tensor_parallel_size, + ) + + def __prepare_model_for_sglang_inference( + self, + model_args: ModelArguments, + gpu_memory_utilization: Optional[float] = None, + tensor_parallel_size: Optional[int] = None, + ): + if not is_sglang_available(): + raise ImportError('SGLang is not available. 
Please install via `pip install -e ".[sglang]"`.') + + from sglang.srt.entrypoints.engine import Engine + from sglang.srt.server_args import ServerArgs + + sgl_server_args = ServerArgs( + model_path=model_args.model_name_or_path, + mem_fraction_static=gpu_memory_utilization, + tensor_parallel_size=tensor_parallel_size, ) + self.backend_model_for_inference = Engine(server_args=sgl_server_args) def __fix_special_tokens(self): # old models/tokenizers may not have these attributes, fixing @@ -490,18 +507,25 @@ def __fix_special_tokens(self): def activate_model_for_inference( self, - use_vllm: bool = False, - **kwargs, + inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface", + gpu_memory_utilization: Optional[float] = None, + tensor_parallel_size: Optional[int] = None, ): if self._activated: logger.warning("You are trying to activate the model for inference, but it is already activated.") return - if use_vllm: + if inference_engine == "vllm": self.__prepare_model_for_vllm_inference( model_args=self.model_args, - vllm_gpu_memory_utilization=kwargs.get("vllm_gpu_memory_utilization"), - vllm_tensor_parallel_size=kwargs.get("vllm_tensor_parallel_size"), + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tensor_parallel_size, + ) + elif inference_engine == "sglang": + self.__prepare_model_for_sglang_inference( + model_args=self.model_args, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tensor_parallel_size, ) else: self.__prepare_model_for_inference( @@ -513,7 +537,7 @@ def activate_model_for_inference( def deactivate_model_for_inference( self, - use_vllm: bool = False, + inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface", ): """Deactivate the model and release the resources. 
@@ -526,15 +550,17 @@ def deactivate_model_for_inference( logger.warning("You are trying to deactivate the model for inference, but it is already deactivated.") return - if use_vllm: + if inference_engine == "vllm": + from vllm.distributed.parallel_state import destroy_model_parallel destroy_model_parallel() del self.backend_model_for_inference.llm_engine.model_executor.driver_worker del self.backend_model_for_inference gc.collect() torch.cuda.empty_cache() + elif inference_engine == "sglang": + pass else: self.backend_model.to("cpu") - pass self._activated = False diff --git a/src/lmflow/pipeline/sglang_inferencer.py b/src/lmflow/pipeline/sglang_inferencer.py new file mode 100644 index 000000000..712413d1d --- /dev/null +++ b/src/lmflow/pipeline/sglang_inferencer.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python +# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved. +import json +import logging + +from typing import Optional, Union + +from transformers import AutoTokenizer + +from lmflow.args import ( + DatasetArguments, + InferencerArguments, + ModelArguments, +) +from lmflow.datasets import Dataset +from lmflow.models.hf_decoder_model import HFDecoderModel +from lmflow.pipeline.base_pipeline import BasePipeline +from lmflow.utils.versioning import is_sglang_available + +logger = logging.getLogger(__name__) + + +if is_sglang_available(): + pass +else: + raise ImportError("SGLang is not available, please install sglang using `pip install -e .[sglang]`.") + + +class SGLangInferencer(BasePipeline): + def __init__( + self, + model_args: ModelArguments, + data_args: DatasetArguments, + inferencer_args: InferencerArguments, + ): + assert inferencer_args.inference_engine == "sglang" + self.model_args = model_args + self.data_args = data_args + self.inferencer_args = inferencer_args + self.eos_token_id = AutoTokenizer.from_pretrained(model_args.model_name_or_path).eos_token_id + self.sampling_params = 
self._parse_args_to_sampling_params(inferencer_args) + + def _parse_args_to_sampling_params( + self, + inference_args: InferencerArguments, + ) -> dict: + sampling_params = { + "use_beam_search": inference_args.use_beam_search, + "n": inference_args.num_output_sequences, + "temperature": inference_args.temperature + 1e-6, + "max_tokens": inference_args.max_new_tokens, + "seed": inference_args.random_seed, + "top_p": inference_args.top_p, + "top_k": inference_args.top_k, + "stop_token_ids": [self.eos_token_id] + inference_args.additional_stop_token_ids, + } + + return sampling_params + + def inference( + self, + model: HFDecoderModel, + dataset: Dataset, + release_gpu: bool = False, + inference_args: Optional[InferencerArguments] = None, + ): + if inference_args: + logger.warning("Overriding the default inference arguments with the provided arguments in .inference()") + sampling_params = self._parse_args_to_sampling_params(inference_args) + else: + sampling_params = self.sampling_params + + model_input = model.prepare_inputs_for_inference( + dataset=dataset, + apply_chat_template=self.inferencer_args.apply_chat_template, + inference_engine="sglang", + ) + + outputs = model.inference( + inputs=model_input, + sampling_params=sampling_params, + release_gpu=release_gpu, + inference_engine="sglang", + gpu_memory_utilization=self.inferencer_args.inference_gpu_memory_utilization, + tensor_parallel_size=self.inferencer_args.inference_tensor_parallel_size, + ) + + if self.inferencer_args.save_results: + self.save_inference_results(outputs, self.inferencer_args.results_path) + + return outputs + + def save_inference_results( + self, + outputs: Union[list[list[str]], list[list[list[int]]]], + save_file_path: str, + ): + with open(save_file_path, "w", encoding="utf-8") as f: + json.dump(outputs, f, ensure_ascii=False, indent=4) + + logger.info(f"Inference results are saved to {save_file_path}.") + + def load_inference_results( + self, + results_path: str, + ) -> 
Union[list[list[str]], list[list[list[int]]]]: + with open(results_path) as f: + results = json.load(f) + + return results diff --git a/src/lmflow/pipeline/vllm_inferencer.py b/src/lmflow/pipeline/vllm_inferencer.py index d2905c34a..9873598f8 100644 --- a/src/lmflow/pipeline/vllm_inferencer.py +++ b/src/lmflow/pipeline/vllm_inferencer.py @@ -180,8 +180,8 @@ def _inference( sampling_params=sampling_params, release_gpu=release_gpu, use_vllm=True, - vllm_gpu_memory_utilization=self.inferencer_args.vllm_gpu_memory_utilization, - vllm_tensor_parallel_size=self.inferencer_args.vllm_tensor_parallel_size, + gpu_memory_utilization=self.inferencer_args.inference_gpu_memory_utilization, + tensor_parallel_size=self.inferencer_args.inference_tensor_parallel_size, ) return outputs @@ -201,7 +201,7 @@ def _distributed_inference( def scheduling_strategy_fn(): # One bundle per tensor parallel worker pg = ray.util.placement_group( - [{"GPU": 1, "CPU": 1}] * self.inferencer_args.vllm_tensor_parallel_size, + [{"GPU": 1, "CPU": 1}] * self.inferencer_args.inference_tensor_parallel_size, strategy="STRICT_PACK", ) return dict( @@ -209,7 +209,7 @@ def scheduling_strategy_fn(): ) resources_kwarg: dict[str, Any] = {} - if self.inferencer_args.vllm_tensor_parallel_size == 1: + if self.inferencer_args.inference_tensor_parallel_size == 1: # For tensor_parallel_size == 1, we simply set num_gpus=1. 
            resources_kwarg["num_gpus"] = 1
         else:
@@ -225,15 +225,15 @@ def __init__(
         self,
         model: HFDecoderModel,
         sampling_params: SamplingParams,
-        vllm_gpu_memory_utilization: float,
-        vllm_tensor_parallel_size: int,
+        gpu_memory_utilization: float,
+        tensor_parallel_size: int,
         release_gpu: bool = False,
     ):
         self.model = copy.deepcopy(model)
         self.model.activate_model_for_inference(
-            use_vllm=True,
-            vllm_gpu_memory_utilization=vllm_gpu_memory_utilization,
-            vllm_tensor_parallel_size=vllm_tensor_parallel_size,
+            inference_engine="vllm",
+            gpu_memory_utilization=gpu_memory_utilization,
+            tensor_parallel_size=tensor_parallel_size,
         )
         self.sampling_params = sampling_params
         self.release_gpu = release_gpu
@@ -260,8 +260,8 @@ def __call__(self, batch: dict[str, np.ndarray]):
             fn_constructor_kwargs={
                 "model": model,
                 "sampling_params": sampling_params,
-                "vllm_gpu_memory_utilization": self.inferencer_args.vllm_gpu_memory_utilization,
-                "vllm_tensor_parallel_size": self.inferencer_args.vllm_tensor_parallel_size,
+                "gpu_memory_utilization": self.inferencer_args.inference_gpu_memory_utilization,
+                "tensor_parallel_size": self.inferencer_args.inference_tensor_parallel_size,
                 "release_gpu": release_gpu,
             },
             **resources_kwarg,
diff --git a/src/lmflow/utils/deprecated.py b/src/lmflow/utils/deprecated.py
new file mode 100644
index 000000000..c2ee06648
--- /dev/null
+++ b/src/lmflow/utils/deprecated.py
@@ -0,0 +1,77 @@
+"""
+Utilities for handling deprecated APIs and maintaining backwards compatibility.
+"""
+
+import functools
+import inspect
+import warnings
+from typing import Any, Callable, Dict
+
+__all__ = ['deprecated_args']
+
+
+def deprecated_args(**deprecated_params: Dict[str, Any]):
+    """
+    Decorator to handle deprecated function arguments.
+
+    Args:
+        **deprecated_params: Mapping of deprecated argument names to their configuration. 
+ Each value should be a dict with: + - 'replacement': Name of the new argument (optional) + - 'mapper': Function to map old value to new value (optional) + - 'message': Custom deprecation message (optional) + + Example: + @deprecated_args( + use_vllm={ + 'replacement': 'inference_engine', + 'mapper': lambda x: 'vllm' if x else 'huggingface', + 'message': "use_vllm is deprecated. Use inference_engine='vllm' instead." + } + ) + def my_function(inference_engine='huggingface', **kwargs): + pass + """ + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args, **kwargs): + # Get function signature to handle both args and kwargs + sig = inspect.signature(func) + bound_args = sig.bind_partial(*args, **kwargs) + bound_args.apply_defaults() + + # Check for deprecated arguments in kwargs + for old_arg, config in deprecated_params.items(): + if old_arg in kwargs: + old_value = kwargs.pop(old_arg) + + # Build deprecation message + if 'message' in config: + message = config['message'] + else: + replacement = config.get('replacement', 'a different argument') + message = ( + f"'{old_arg}' is deprecated and will be removed in a future version. " + f"Please use '{replacement}' instead." 
+ ) + + warnings.warn(message, DeprecationWarning, stacklevel=2) + + # Map old value to new argument if specified + if 'replacement' in config: + new_arg = config['replacement'] + + # Apply mapper function if provided + if 'mapper' in config: + new_value = config['mapper'](old_value) + else: + new_value = old_value + + # Only set the new argument if it wasn't already provided + if new_arg not in kwargs: + kwargs[new_arg] = new_value + + return func(*args, **kwargs) + + return wrapper + return decorator \ No newline at end of file diff --git a/src/lmflow/utils/versioning.py b/src/lmflow/utils/versioning.py index a08bf119f..0b52798fd 100644 --- a/src/lmflow/utils/versioning.py +++ b/src/lmflow/utils/versioning.py @@ -60,6 +60,10 @@ def is_vllm_available(): return _is_package_available("vllm") +def is_sglang_available(): + return _is_package_available("sglang") + + def is_flash_attn_available(): return _is_package_available("flash_attn", skippable=True) diff --git a/tests/datasets/conftest.py b/tests/datasets/conftest.py new file mode 100644 index 000000000..0f16a0a60 --- /dev/null +++ b/tests/datasets/conftest.py @@ -0,0 +1,13 @@ +import pytest + +from lmflow.args import DatasetArguments +from lmflow.datasets.dataset import Dataset + + +@pytest.fixture +def dataset_inference_conversation() -> Dataset: + dataset = Dataset(DatasetArguments(dataset_path=None)) + dataset = dataset.from_dict( + {"type": "conversation", "instances": [{"messages": [{"role": "user", "content": "Hello, how are you?"}]}]} + ) + return dataset \ No newline at end of file diff --git a/tests/pipeline/test_memory_safe_vllm_inferencer.py b/tests/pipeline/test_memory_safe_vllm_inferencer.py index 540b208b9..afa688b4d 100644 --- a/tests/pipeline/test_memory_safe_vllm_inferencer.py +++ b/tests/pipeline/test_memory_safe_vllm_inferencer.py @@ -32,8 +32,8 @@ results_path="./data/mem_safe_vllm_res.json", use_vllm=True, enable_decode_inference_result=False, - vllm_gpu_memory_utilization=0.95, - 
vllm_tensor_parallel_size=2, + inference_gpu_memory_utilization=0.95, + inference_tensor_parallel_size=2, ) diff --git a/tests/pipeline/test_sglang_infernecer.py b/tests/pipeline/test_sglang_infernecer.py new file mode 100644 index 000000000..469bcf2ed --- /dev/null +++ b/tests/pipeline/test_sglang_infernecer.py @@ -0,0 +1,35 @@ +import pytest + +from lmflow.args import DatasetArguments, InferencerArguments, ModelArguments +from lmflow.datasets.dataset import Dataset +from lmflow.models.hf_decoder_model import HFDecoderModel +from lmflow.pipeline.sglang_inferencer import SGLangInferencer + +from tests.datasets.conftest import dataset_inference_conversation + + +@pytest.fixture +def sglang_test_model_args() -> ModelArguments: + return ModelArguments(model_name_or_path="Qwen/Qwen3-4B-Instruct-2507") + +@pytest.fixture +def sglang_test_inferencer_args() -> InferencerArguments: + return InferencerArguments(inference_engine="sglang") + +if __name__ == "__main__": + def test_sglang_inferencer( + dataset_inference_conversation: Dataset, + sglang_test_model_args: ModelArguments, + sglang_test_inferencer_args: InferencerArguments + ): + model = HFDecoderModel(model_args=sglang_test_model_args) + sglang_inferencer = SGLangInferencer( + data_args=dataset_inference_conversation.data_args, + model_args=sglang_test_model_args, + inferencer_args=sglang_test_inferencer_args + ) + sglang_inferencer.inference( + model=model, + dataset=dataset_inference_conversation, + ) + test_sglang_inferencer() \ No newline at end of file From bd29f99fe325b7c3935af1d3bf3486c4d4c2626b Mon Sep 17 00:00:00 2001 From: YizhenJia Date: Sun, 23 Nov 2025 18:28:31 +0800 Subject: [PATCH 2/6] [fix] sglang bug fix, unit test --- src/lmflow/models/hf_decoder_model.py | 1 + src/lmflow/models/hf_model_mixin.py | 2 +- src/lmflow/pipeline/sglang_inferencer.py | 13 ++++--- tests/datasets/conftest.py | 14 ++++++++ tests/pipeline/test_sglang_infernecer.py | 45 ++++++++++++++---------- 5 files changed, 51 
insertions(+), 24 deletions(-) diff --git a/src/lmflow/models/hf_decoder_model.py b/src/lmflow/models/hf_decoder_model.py index a4ec719fd..11bb72318 100644 --- a/src/lmflow/models/hf_decoder_model.py +++ b/src/lmflow/models/hf_decoder_model.py @@ -504,6 +504,7 @@ def preprocess_conversation(sample): return sample_out + # TODO: investigate performance issue dataset = dataset.map( preprocess_conversation, num_proc=dataset.data_args.preprocessing_num_workers, diff --git a/src/lmflow/models/hf_model_mixin.py b/src/lmflow/models/hf_model_mixin.py index 6ea1cf1c5..f906ffd31 100644 --- a/src/lmflow/models/hf_model_mixin.py +++ b/src/lmflow/models/hf_model_mixin.py @@ -479,7 +479,7 @@ def __prepare_model_for_sglang_inference( sgl_server_args = ServerArgs( model_path=model_args.model_name_or_path, mem_fraction_static=gpu_memory_utilization, - tensor_parallel_size=tensor_parallel_size, + tp_size=tensor_parallel_size, ) self.backend_model_for_inference = Engine(server_args=sgl_server_args) diff --git a/src/lmflow/pipeline/sglang_inferencer.py b/src/lmflow/pipeline/sglang_inferencer.py index 712413d1d..21f2cd385 100644 --- a/src/lmflow/pipeline/sglang_inferencer.py +++ b/src/lmflow/pipeline/sglang_inferencer.py @@ -44,12 +44,14 @@ def _parse_args_to_sampling_params( self, inference_args: InferencerArguments, ) -> dict: + if inference_args.use_beam_search: + logger.warning("`use_beam_search` is ignored, as SGLang does not support currently.") + sampling_params = { - "use_beam_search": inference_args.use_beam_search, "n": inference_args.num_output_sequences, "temperature": inference_args.temperature + 1e-6, - "max_tokens": inference_args.max_new_tokens, - "seed": inference_args.random_seed, + "max_new_tokens": inference_args.max_new_tokens, + "sampling_seed": inference_args.random_seed, "top_p": inference_args.top_p, "top_k": inference_args.top_k, "stop_token_ids": [self.eos_token_id] + inference_args.additional_stop_token_ids, @@ -70,15 +72,18 @@ def inference( else: 
         sampling_params = self.sampling_params
 
+        # TODO: we need lmflow data sample protocol for better programming experience, data tracking, etc.
         model_input = model.prepare_inputs_for_inference(
             dataset=dataset,
             apply_chat_template=self.inferencer_args.apply_chat_template,
             inference_engine="sglang",
         )
+        # handling n>1 since we don't want one-to-many mapping
+        model_input = [sample for sample in model_input for _ in range(sampling_params["n"])]
 
         outputs = model.inference(
             inputs=model_input,
-            sampling_params=sampling_params,
+            sampling_params={**sampling_params, "n": 1},
             release_gpu=release_gpu,
             inference_engine="sglang",
             gpu_memory_utilization=self.inferencer_args.inference_gpu_memory_utilization,
diff --git a/tests/datasets/conftest.py b/tests/datasets/conftest.py
index 0f16a0a60..5660d4207 100644
--- a/tests/datasets/conftest.py
+++ b/tests/datasets/conftest.py
@@ -10,4 +10,18 @@ def dataset_inference_conversation() -> Dataset:
     dataset = dataset.from_dict(
         {"type": "conversation", "instances": [{"messages": [{"role": "user", "content": "Hello, how are you?"}]}]}
     )
+    return dataset
+
+@pytest.fixture
+def dataset_inference_conversation_batch() -> Dataset:
+    dataset = Dataset(DatasetArguments(dataset_path=None))
+    dataset = dataset.from_dict(
+        {
+            "type": "conversation",
+            "instances": [
+                {"messages": [{"role": "user", "content": "Hello, how are you?"}]},
+                {"messages": [{"role": "user", "content": "What's the capital of France?"}]},
+            ]
+        }
+    )
     return dataset
\ No newline at end of file
diff --git a/tests/pipeline/test_sglang_infernecer.py b/tests/pipeline/test_sglang_infernecer.py
index 469bcf2ed..1aba53481 100644
--- a/tests/pipeline/test_sglang_infernecer.py
+++ b/tests/pipeline/test_sglang_infernecer.py
@@ -5,7 +5,7 @@
 from lmflow.models.hf_decoder_model import HFDecoderModel
 from lmflow.pipeline.sglang_inferencer import SGLangInferencer
 
-from tests.datasets.conftest import dataset_inference_conversation
+from tests.datasets.conftest import 
dataset_inference_conversation_batch @pytest.fixture @@ -14,22 +14,29 @@ def sglang_test_model_args() -> ModelArguments: @pytest.fixture def sglang_test_inferencer_args() -> InferencerArguments: - return InferencerArguments(inference_engine="sglang") + return InferencerArguments( + inference_engine="sglang", + inference_gpu_memory_utilization=0.8, + num_output_sequences=2, + ) -if __name__ == "__main__": - def test_sglang_inferencer( - dataset_inference_conversation: Dataset, - sglang_test_model_args: ModelArguments, - sglang_test_inferencer_args: InferencerArguments - ): - model = HFDecoderModel(model_args=sglang_test_model_args) - sglang_inferencer = SGLangInferencer( - data_args=dataset_inference_conversation.data_args, - model_args=sglang_test_model_args, - inferencer_args=sglang_test_inferencer_args - ) - sglang_inferencer.inference( - model=model, - dataset=dataset_inference_conversation, - ) - test_sglang_inferencer() \ No newline at end of file +def test_sglang_inferencer( + dataset_inference_conversation_batch: Dataset, + sglang_test_model_args: ModelArguments, + sglang_test_inferencer_args: InferencerArguments +): + model = HFDecoderModel(model_args=sglang_test_model_args) + sglang_inferencer = SGLangInferencer( + data_args=dataset_inference_conversation_batch.data_args, + model_args=sglang_test_model_args, + inferencer_args=sglang_test_inferencer_args + ) + res = sglang_inferencer.inference( + model=model, + dataset=dataset_inference_conversation_batch, + ) + assert len(res) == 4 + assert res[0]['input'] == dataset_inference_conversation_batch.backend_dataset[0]['templated'] + assert res[1]['input'] == dataset_inference_conversation_batch.backend_dataset[0]['templated'] + assert res[2]['input'] == dataset_inference_conversation_batch.backend_dataset[1]['templated'] + assert res[3]['input'] == dataset_inference_conversation_batch.backend_dataset[1]['templated'] \ No newline at end of file From 83685356ea7e9122944e0c12d0c4ed106cf280fb Mon Sep 17 00:00:00 
2001 From: YizhenJia Date: Sun, 23 Nov 2025 18:36:58 +0800 Subject: [PATCH 3/6] [versioning] bump lmflow version to 1.1.0 --- requirements.txt | 2 +- src/lmflow/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3db38164d..caaaa3cb2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ packaging numpy -datasets==2.14.6 +datasets==3.6.0 tokenizers>=0.13.3 peft>=0.10.0 torch>=2.0.1 diff --git a/src/lmflow/version.py b/src/lmflow/version.py index 5becc17c0..6849410aa 100644 --- a/src/lmflow/version.py +++ b/src/lmflow/version.py @@ -1 +1 @@ -__version__ = "1.0.0" +__version__ = "1.1.0" From 235935d7a3caf5607e9e3bee3ca11edf0fa1c57c Mon Sep 17 00:00:00 2001 From: YizhenJia Date: Sun, 23 Nov 2025 18:40:41 +0800 Subject: [PATCH 4/6] [ci] lint --- src/lmflow/args.py | 45 ++++++++++++++---------- src/lmflow/models/hf_decoder_model.py | 43 ++++++++++++---------- src/lmflow/models/hf_model_mixin.py | 12 +++---- src/lmflow/pipeline/sglang_inferencer.py | 5 ++- src/lmflow/utils/deprecated.py | 44 ++++++++++++----------- tests/datasets/conftest.py | 7 ++-- tests/pipeline/test_sglang_infernecer.py | 22 ++++++------ 7 files changed, 98 insertions(+), 80 deletions(-) diff --git a/src/lmflow/args.py b/src/lmflow/args.py index 1a1228ff0..4d9f143ff 100644 --- a/src/lmflow/args.py +++ b/src/lmflow/args.py @@ -995,35 +995,44 @@ class InferencerArguments: # vllm inference args use_vllm: Optional[bool] = field( - default=None, - metadata={"help": "Whether to use VLLM for inference, By default None. Deprecated, use inference_engine instead."} + default=None, + metadata={ + "help": "Whether to use VLLM for inference, By default None. Deprecated, use inference_engine instead." + }, ) vllm_tensor_parallel_size: Optional[int] = field( - default=None, - metadata={"help": "The tensor parallel size for VLLM inference. 
Deprecated, use inference_tensor_parallel_size instead."} + default=None, + metadata={ + "help": ( + "The tensor parallel size for VLLM inference. Deprecated, use inference_tensor_parallel_size instead." + ) + }, ) vllm_gpu_memory_utilization: Optional[float] = field( - default=None, - metadata={"help": "The GPU memory utilization for VLLM inference. Deprecated, use inference_gpu_memory_utilization instead."} + default=None, + metadata={ + "help": ( + "The GPU memory utilization for VLLM inference. " + "Deprecated, use inference_gpu_memory_utilization instead." + ) + }, ) - + # inference engine args inference_engine: Optional[str] = field( - default="huggingface", + default="huggingface", metadata={ "help": "The inference engine to use, by default huggingface.", - "choices": ["huggingface", "vllm", "sglang"] - } + "choices": ["huggingface", "vllm", "sglang"], + }, ) inference_tensor_parallel_size: Optional[int] = field( - default=1, - metadata={"help": "The tensor parallel size for inference."} + default=1, metadata={"help": "The tensor parallel size for inference."} ) inference_gpu_memory_utilization: Optional[float] = field( - default=0.95, - metadata={"help": "The GPU memory utilization for inference."} + default=0.95, metadata={"help": "The GPU memory utilization for inference."} ) - + # Args for result saving save_results: Optional[bool] = field(default=False, metadata={"help": "Whether to save inference results."}) results_path: Optional[str] = field(default=None, metadata={"help": "The path of inference results."}) @@ -1044,21 +1053,21 @@ def __post_init__(self): raise ValueError("The results_path must be a json file.") else: Path(self.results_path).parent.mkdir(parents=True, exist_ok=True) - + if self.use_vllm is True: logger.warning( "Inference engine is set to vllm. You've specified `use_vllm`. This argument is deprecated and " "will be removed in a future version. Please use `inference_engine` instead." 
) self.inference_engine = "vllm" - + if self.vllm_tensor_parallel_size is not None: logger.warning( "You've specified `vllm_tensor_parallel_size`. This argument is deprecated and " "will be removed in a future version. Please use `inference_tensor_parallel_size` instead." ) self.inference_tensor_parallel_size = self.vllm_tensor_parallel_size - + if self.vllm_gpu_memory_utilization is not None: logger.warning( "You've specified `vllm_gpu_memory_utilization`. This argument is deprecated and " diff --git a/src/lmflow/models/hf_decoder_model.py b/src/lmflow/models/hf_decoder_model.py index 11bb72318..d711daf96 100644 --- a/src/lmflow/models/hf_decoder_model.py +++ b/src/lmflow/models/hf_decoder_model.py @@ -276,20 +276,23 @@ def decode(self, input, **kwargs) -> Union[str, list[str]]: @deprecated_args( use_vllm={ - 'replacement': 'inference_engine', - 'mapper': lambda x: 'vllm' if x is True else 'huggingface', - 'message': "use_vllm is deprecated and will be removed in a future version. Please use `inference_engine='vllm'` instead." + "replacement": "inference_engine", + "mapper": lambda x: "vllm" if x is True else "huggingface", + "message": ( + "use_vllm is deprecated and will be removed in a future version. " + "Please use `inference_engine='vllm'` instead." + ), } ) def inference( - self, - inputs: Union[str, list[str], torch.Tensor], + self, + inputs: Union[str, list[str], torch.Tensor], sampling_params: Optional[Union[dict, "SamplingParams"]] = None, - release_gpu: bool = False, - inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface", + release_gpu: bool = False, + inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface", gpu_memory_utilization: Optional[float] = None, tensor_parallel_size: Optional[int] = None, - **kwargs + **kwargs, ): """ Perform generation process of the model. 
@@ -318,14 +321,14 @@ def inference( """ if isinstance(inputs, str): inputs = [inputs] - + if not self._activated: self.activate_model_for_inference( inference_engine=inference_engine, gpu_memory_utilization=gpu_memory_utilization, tensor_parallel_size=tensor_parallel_size, ) - + if inference_engine == "vllm": res = self.__vllm_inference(inputs=inputs, sampling_params=sampling_params) elif inference_engine == "sglang": @@ -418,29 +421,31 @@ def __vllm_inference( final_output.append({"input": output.prompt, "output": output_list}) return final_output - + def __sglang_inference( self, inputs: list[str], sampling_params: Optional[dict] = None, ): - """Perform SGLang inference process of the model. - """ + """Perform SGLang inference process of the model.""" sglang_outputs = self.backend_model_for_inference.generate( prompt=inputs, sampling_params=sampling_params, ) # TODO: unified lmflow sample format for idx, output in enumerate(sglang_outputs): - output['input'] = inputs[idx] - output['output'] = output.pop('text') + output["input"] = inputs[idx] + output["output"] = output.pop("text") return sglang_outputs @deprecated_args( use_vllm={ - 'replacement': 'inference_engine', - 'mapper': lambda x: 'vllm' if x is True else 'huggingface', - 'message': "use_vllm is deprecated and will be removed in a future version. Please use `inference_engine='vllm'` instead." + "replacement": "inference_engine", + "mapper": lambda x: "vllm" if x is True else "huggingface", + "message": ( + "use_vllm is deprecated and will be removed in a future version. " + "Please use `inference_engine='vllm'` instead." 
+ ), } ) def prepare_inputs_for_inference( @@ -528,7 +533,7 @@ def preprocess_conversation(sample): inference_inputs = ray.data.from_items( inference_inputs ) # -> dict[str, np.ndarray], {"item": array(['...', '...', '...'])} - + if inference_engine == "sglang" and self.tokenizer.bos_token: # in consistent with sglang bench_serving.py demo inference_inputs = [sentence.replace(self.tokenizer.bos_token, "") for sentence in inference_inputs] diff --git a/src/lmflow/models/hf_model_mixin.py b/src/lmflow/models/hf_model_mixin.py index f906ffd31..9dda62409 100644 --- a/src/lmflow/models/hf_model_mixin.py +++ b/src/lmflow/models/hf_model_mixin.py @@ -25,8 +25,7 @@ from lmflow.models.base_model import BaseModel from lmflow.utils.constants import LMFLOW_LORA_TARGET_MODULES_MAPPING from lmflow.utils.envs import is_accelerate_env -from lmflow.utils.versioning import is_deepspeed_available, is_vllm_available, is_sglang_available - +from lmflow.utils.versioning import is_deepspeed_available, is_sglang_available, is_vllm_available logger = logging.getLogger(__name__) @@ -452,7 +451,7 @@ def __prepare_model_for_vllm_inference( ): if not is_vllm_available(): raise ImportError('VLLM is not available. Please install via `pip install -e ".[vllm]"`.') - + from vllm import LLM self.backend_model_for_inference = LLM( @@ -463,7 +462,7 @@ def __prepare_model_for_vllm_inference( gpu_memory_utilization=gpu_memory_utilization, tensor_parallel_size=tensor_parallel_size, ) - + def __prepare_model_for_sglang_inference( self, model_args: ModelArguments, @@ -472,10 +471,10 @@ def __prepare_model_for_sglang_inference( ): if not is_sglang_available(): raise ImportError('SGLang is not available. 
Please install via `pip install -e ".[sglang]"`.') - + from sglang.srt.entrypoints.engine import Engine from sglang.srt.server_args import ServerArgs - + sgl_server_args = ServerArgs( model_path=model_args.model_name_or_path, mem_fraction_static=gpu_memory_utilization, @@ -552,6 +551,7 @@ def deactivate_model_for_inference( if inference_engine == "vllm": from vllm.distributed.parallel_state import destroy_model_parallel + destroy_model_parallel() del self.backend_model_for_inference.llm_engine.model_executor.driver_worker del self.backend_model_for_inference diff --git a/src/lmflow/pipeline/sglang_inferencer.py b/src/lmflow/pipeline/sglang_inferencer.py index 21f2cd385..442975ec4 100644 --- a/src/lmflow/pipeline/sglang_inferencer.py +++ b/src/lmflow/pipeline/sglang_inferencer.py @@ -2,7 +2,6 @@ # Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved. import json import logging - from typing import Optional, Union from transformers import AutoTokenizer @@ -46,7 +45,7 @@ def _parse_args_to_sampling_params( ) -> dict: if inference_args.use_beam_search: logger.warning("`use_beam_search` is ignored, as SGLang does not support currently.") - + sampling_params = { "n": inference_args.num_output_sequences, "temperature": inference_args.temperature + 1e-6, @@ -56,7 +55,7 @@ def _parse_args_to_sampling_params( "top_k": inference_args.top_k, "stop_token_ids": [self.eos_token_id] + inference_args.additional_stop_token_ids, } - + return sampling_params def inference( diff --git a/src/lmflow/utils/deprecated.py b/src/lmflow/utils/deprecated.py index c2ee06648..d385b53a9 100644 --- a/src/lmflow/utils/deprecated.py +++ b/src/lmflow/utils/deprecated.py @@ -5,22 +5,22 @@ import functools import inspect import warnings -from typing import Any, Callable, Dict +from typing import Any, Callable -__all__ = ['deprecated_args'] +__all__ = ["deprecated_args"] -def deprecated_args(**deprecated_params: Dict[str, Any]): +def deprecated_args(**deprecated_params: 
dict[str, Any]): """ Decorator to handle deprecated function arguments. - + Args: **deprecated_params: Mapping of deprecated argument names to their configuration. Each value should be a dict with: - 'replacement': Name of the new argument (optional) - 'mapper': Function to map old value to new value (optional) - 'message': Custom deprecation message (optional) - + Example: @deprecated_args( use_vllm={ @@ -32,6 +32,7 @@ def deprecated_args(**deprecated_params: Dict[str, Any]): def my_function(inference_engine='huggingface', **kwargs): pass """ + def decorator(func: Callable) -> Callable: @functools.wraps(func) def wrapper(*args, **kwargs): @@ -39,39 +40,40 @@ def wrapper(*args, **kwargs): sig = inspect.signature(func) bound_args = sig.bind_partial(*args, **kwargs) bound_args.apply_defaults() - + # Check for deprecated arguments in kwargs for old_arg, config in deprecated_params.items(): if old_arg in kwargs: old_value = kwargs.pop(old_arg) - + # Build deprecation message - if 'message' in config: - message = config['message'] + if "message" in config: + message = config["message"] else: - replacement = config.get('replacement', 'a different argument') + replacement = config.get("replacement", "a different argument") message = ( f"'{old_arg}' is deprecated and will be removed in a future version. " f"Please use '{replacement}' instead." 
) - + warnings.warn(message, DeprecationWarning, stacklevel=2) - + # Map old value to new argument if specified - if 'replacement' in config: - new_arg = config['replacement'] - + if "replacement" in config: + new_arg = config["replacement"] + # Apply mapper function if provided - if 'mapper' in config: - new_value = config['mapper'](old_value) + if "mapper" in config: + new_value = config["mapper"](old_value) else: new_value = old_value - + # Only set the new argument if it wasn't already provided if new_arg not in kwargs: kwargs[new_arg] = new_value - + return func(*args, **kwargs) - + return wrapper - return decorator \ No newline at end of file + + return decorator diff --git a/tests/datasets/conftest.py b/tests/datasets/conftest.py index 5660d4207..90f299161 100644 --- a/tests/datasets/conftest.py +++ b/tests/datasets/conftest.py @@ -12,16 +12,17 @@ def dataset_inference_conversation() -> Dataset: ) return dataset + @pytest.fixture def dataset_inference_conversation_batch() -> Dataset: dataset = Dataset(DatasetArguments(dataset_path=None)) dataset = dataset.from_dict( { - "type": "conversation", + "type": "conversation", "instances": [ {"messages": [{"role": "user", "content": "Hello, how are you?"}]}, {"messages": [{"role": "user", "content": "What's the capital of France?"}]}, - ] + ], } ) - return dataset \ No newline at end of file + return dataset diff --git a/tests/pipeline/test_sglang_infernecer.py b/tests/pipeline/test_sglang_infernecer.py index 1aba53481..bd3fe4646 100644 --- a/tests/pipeline/test_sglang_infernecer.py +++ b/tests/pipeline/test_sglang_infernecer.py @@ -1,6 +1,6 @@ import pytest -from lmflow.args import DatasetArguments, InferencerArguments, ModelArguments +from lmflow.args import InferencerArguments, ModelArguments from lmflow.datasets.dataset import Dataset from lmflow.models.hf_decoder_model import HFDecoderModel from lmflow.pipeline.sglang_inferencer import SGLangInferencer @@ -12,31 +12,33 @@ def sglang_test_model_args() -> 
ModelArguments: return ModelArguments(model_name_or_path="Qwen/Qwen3-4B-Instruct-2507") + @pytest.fixture def sglang_test_inferencer_args() -> InferencerArguments: return InferencerArguments( - inference_engine="sglang", + inference_engine="sglang", inference_gpu_memory_utilization=0.8, num_output_sequences=2, ) + def test_sglang_inferencer( dataset_inference_conversation_batch: Dataset, sglang_test_model_args: ModelArguments, - sglang_test_inferencer_args: InferencerArguments + sglang_test_inferencer_args: InferencerArguments, ): model = HFDecoderModel(model_args=sglang_test_model_args) sglang_inferencer = SGLangInferencer( - data_args=dataset_inference_conversation_batch.data_args, - model_args=sglang_test_model_args, - inferencer_args=sglang_test_inferencer_args + data_args=dataset_inference_conversation_batch.data_args, + model_args=sglang_test_model_args, + inferencer_args=sglang_test_inferencer_args, ) res = sglang_inferencer.inference( model=model, dataset=dataset_inference_conversation_batch, ) assert len(res) == 4 - assert res[0]['input'] == dataset_inference_conversation_batch.backend_dataset[0]['templated'] - assert res[1]['input'] == dataset_inference_conversation_batch.backend_dataset[0]['templated'] - assert res[2]['input'] == dataset_inference_conversation_batch.backend_dataset[1]['templated'] - assert res[3]['input'] == dataset_inference_conversation_batch.backend_dataset[1]['templated'] \ No newline at end of file + assert res[0]["input"] == dataset_inference_conversation_batch.backend_dataset[0]["templated"] + assert res[1]["input"] == dataset_inference_conversation_batch.backend_dataset[0]["templated"] + assert res[2]["input"] == dataset_inference_conversation_batch.backend_dataset[1]["templated"] + assert res[3]["input"] == dataset_inference_conversation_batch.backend_dataset[1]["templated"] From 0bf8d91ab44499bda43669adc2e64aaa9661b703 Mon Sep 17 00:00:00 2001 From: Eric Date: Wed, 26 Nov 2025 00:32:04 +0800 Subject: [PATCH 5/6] [sglang] 
deterministic inference --- src/lmflow/args.py | 28 +++++++++++++++++ src/lmflow/models/hf_decoder_model.py | 19 +++++++++++- src/lmflow/models/hf_model_mixin.py | 10 ++++++- src/lmflow/pipeline/sglang_inferencer.py | 3 ++ tests/pipeline/test_sglang_infernecer.py | 38 ++++++++++++++++++++++++ 5 files changed, 96 insertions(+), 2 deletions(-) diff --git a/src/lmflow/args.py b/src/lmflow/args.py index 4d9f143ff..7d935d9c7 100644 --- a/src/lmflow/args.py +++ b/src/lmflow/args.py @@ -947,6 +947,10 @@ class InferencerArguments: default=False, metadata={"help": "whether turn on true random sampling during inference."}, ) + return_logprob: Optional[bool] = field( + default=False, + metadata={"help": "whether to return log probability during inference."}, + ) use_accelerator: Optional[bool] = field( default=None, metadata={"help": "[Deprecated] Whether to use Huggingface Accelerator instead of Deepspeed"}, @@ -1032,6 +1036,16 @@ class InferencerArguments: inference_gpu_memory_utilization: Optional[float] = field( default=0.95, metadata={"help": "The GPU memory utilization for inference."} ) + enable_deterministic_inference: bool = field( + default=False, + metadata={"help": "Whether to enable deterministic inference. Only supported for SGLang inference engine currently."}, + ) + attention_backend: Optional[str] = field( + default=None, + metadata={"help": ("The attention backend to use. Only supported for SGLang inference engine currently. " + "Please leave it as None to let SGLang automatically choose if you're not sure.") + }, + ) # Args for result saving save_results: Optional[bool] = field(default=False, metadata={"help": "Whether to save inference results."}) @@ -1074,6 +1088,20 @@ def __post_init__(self): "will be removed in a future version. Please use `inference_gpu_memory_utilization` instead." 
) self.inference_gpu_memory_utilization = self.vllm_gpu_memory_utilization + + if self.inference_engine != "sglang": + if self.return_logprob: + logger.warning("`return_logprob` is only supported for SGLang inference engine currently. ") + + if self.inference_engine == "sglang": + if self.enable_deterministic_inference: + if self.attention_backend is None: + self.attention_backend = "fa3" + logger.warning("`enable_deterministic_inference` is enabled, but `attention_backend` is not specified. " + "Using `fa3` as the attention backend by default.") + else: + assert self.attention_backend in ["fa3", "flashinfer", "triton"], ( + "Invalid attention backend. Please choose from 'fa3', 'flashinfer', or 'triton'.") @dataclass diff --git a/src/lmflow/models/hf_decoder_model.py b/src/lmflow/models/hf_decoder_model.py index d711daf96..221712e12 100644 --- a/src/lmflow/models/hf_decoder_model.py +++ b/src/lmflow/models/hf_decoder_model.py @@ -288,10 +288,13 @@ def inference( self, inputs: Union[str, list[str], torch.Tensor], sampling_params: Optional[Union[dict, "SamplingParams"]] = None, + return_logprob: bool = False, release_gpu: bool = False, inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface", gpu_memory_utilization: Optional[float] = None, tensor_parallel_size: Optional[int] = None, + enable_deterministic_inference: bool = False, + attention_backend: Optional[str] = None, **kwargs, ): """ @@ -305,6 +308,8 @@ def inference( When the inference engine is "huggingface", this should be a tensor. sampling_params : Optional[Union[dict, "SamplingParams"]], optional The sampling parameters to use, by default None. + return_logprob : bool, optional + Whether to return log probability during inference, by default False. release_gpu : bool, optional Whether to release the GPU resource after inference, by default False. 
inference_engine : Literal["huggingface", "vllm", "sglang"], optional @@ -313,6 +318,10 @@ def inference( The GPU memory utilization to use, by default None. tensor_parallel_size : int, optional The tensor parallel size to use, by default None. + enable_deterministic_inference : bool, optional + Whether to enable deterministic inference, by default False. + attention_backend : Optional[str], optional + The attention backend to use, by default None. Returns ------------ @@ -327,12 +336,18 @@ def inference( inference_engine=inference_engine, gpu_memory_utilization=gpu_memory_utilization, tensor_parallel_size=tensor_parallel_size, + enable_deterministic_inference=enable_deterministic_inference, + attention_backend=attention_backend, ) if inference_engine == "vllm": res = self.__vllm_inference(inputs=inputs, sampling_params=sampling_params) elif inference_engine == "sglang": - res = self.__sglang_inference(inputs=inputs, sampling_params=sampling_params) + res = self.__sglang_inference( + inputs=inputs, + sampling_params=sampling_params, + return_logprob=return_logprob, + ) else: res = self.__inference(inputs=inputs, **kwargs) @@ -426,11 +441,13 @@ def __sglang_inference( self, inputs: list[str], sampling_params: Optional[dict] = None, + return_logprob: bool = False, ): """Perform SGLang inference process of the model.""" sglang_outputs = self.backend_model_for_inference.generate( prompt=inputs, sampling_params=sampling_params, + return_logprob=return_logprob, ) # TODO: unified lmflow sample format for idx, output in enumerate(sglang_outputs): diff --git a/src/lmflow/models/hf_model_mixin.py b/src/lmflow/models/hf_model_mixin.py index 9dda62409..dce5bd830 100644 --- a/src/lmflow/models/hf_model_mixin.py +++ b/src/lmflow/models/hf_model_mixin.py @@ -468,6 +468,8 @@ def __prepare_model_for_sglang_inference( model_args: ModelArguments, gpu_memory_utilization: Optional[float] = None, tensor_parallel_size: Optional[int] = None, + enable_deterministic_inference: bool = False, 
+ attention_backend: Optional[str] = None, ): if not is_sglang_available(): raise ImportError('SGLang is not available. Please install via `pip install -e ".[sglang]"`.') @@ -479,6 +481,8 @@ def __prepare_model_for_sglang_inference( model_path=model_args.model_name_or_path, mem_fraction_static=gpu_memory_utilization, tp_size=tensor_parallel_size, + enable_deterministic_inference=enable_deterministic_inference, + attention_backend=attention_backend, ) self.backend_model_for_inference = Engine(server_args=sgl_server_args) @@ -509,6 +513,8 @@ def activate_model_for_inference( inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface", gpu_memory_utilization: Optional[float] = None, tensor_parallel_size: Optional[int] = None, + enable_deterministic_inference: bool = False, + attention_backend: Optional[str] = None, ): if self._activated: logger.warning("You are trying to activate the model for inference, but it is already activated.") @@ -525,6 +531,8 @@ model_args=self.model_args, gpu_memory_utilization=gpu_memory_utilization, tensor_parallel_size=tensor_parallel_size, + enable_deterministic_inference=enable_deterministic_inference, + attention_backend=attention_backend, ) else: self.__prepare_model_for_inference( @@ -558,7 +566,7 @@ def deactivate_model_for_inference( gc.collect() torch.cuda.empty_cache() elif inference_engine == "sglang": - pass + self.backend_model_for_inference.shutdown() else: self.backend_model.to("cpu") diff --git a/src/lmflow/pipeline/sglang_inferencer.py b/src/lmflow/pipeline/sglang_inferencer.py index 442975ec4..c2380cb9e 100644 --- a/src/lmflow/pipeline/sglang_inferencer.py +++ b/src/lmflow/pipeline/sglang_inferencer.py @@ -83,10 +83,13 @@ def inference( outputs = model.inference( inputs=model_input, sampling_params={**sampling_params, "n": 1}, + return_logprob=self.inferencer_args.return_logprob, release_gpu=release_gpu, inference_engine="sglang", 
gpu_memory_utilization=self.inferencer_args.inference_gpu_memory_utilization, tensor_parallel_size=self.inferencer_args.inference_tensor_parallel_size, + enable_deterministic_inference=self.inferencer_args.enable_deterministic_inference, + attention_backend=self.inferencer_args.attention_backend, ) if self.inferencer_args.save_results: diff --git a/tests/pipeline/test_sglang_infernecer.py b/tests/pipeline/test_sglang_infernecer.py index bd3fe4646..16795f386 100644 --- a/tests/pipeline/test_sglang_infernecer.py +++ b/tests/pipeline/test_sglang_infernecer.py @@ -1,4 +1,9 @@ +from typing import List, Tuple + +import numpy as np import pytest +from sglang.srt.entrypoints.engine import Engine +from sglang.srt.server_args import ServerArgs from lmflow.args import InferencerArguments, ModelArguments from lmflow.datasets.dataset import Dataset @@ -19,6 +24,9 @@ def sglang_test_inferencer_args() -> InferencerArguments: inference_engine="sglang", inference_gpu_memory_utilization=0.8, num_output_sequences=2, + enable_deterministic_inference=True, + attention_backend="fa3", + return_logprob=True, ) @@ -27,6 +35,9 @@ def test_sglang_inferencer( sglang_test_model_args: ModelArguments, sglang_test_inferencer_args: InferencerArguments, ): + def parse_logprob(logprob_list: List[Tuple[float, int, None]]) -> List[float]: + return np.array([logprob for logprob, _, _ in logprob_list]) + model = HFDecoderModel(model_args=sglang_test_model_args) sglang_inferencer = SGLangInferencer( data_args=dataset_inference_conversation_batch.data_args, @@ -36,9 +47,36 @@ def test_sglang_inferencer( res = sglang_inferencer.inference( model=model, dataset=dataset_inference_conversation_batch, + release_gpu=True, ) assert len(res) == 4 assert res[0]["input"] == dataset_inference_conversation_batch.backend_dataset[0]["templated"] assert res[1]["input"] == dataset_inference_conversation_batch.backend_dataset[0]["templated"] assert res[2]["input"] == 
dataset_inference_conversation_batch.backend_dataset[1]["templated"] assert res[3]["input"] == dataset_inference_conversation_batch.backend_dataset[1]["templated"] + + # test consistency + sgl_server_args = ServerArgs( + model_path=sglang_test_model_args.model_name_or_path, + mem_fraction_static=sglang_test_inferencer_args.inference_gpu_memory_utilization, + tp_size=sglang_test_inferencer_args.inference_tensor_parallel_size, + enable_deterministic_inference=sglang_test_inferencer_args.enable_deterministic_inference, + attention_backend=sglang_test_inferencer_args.attention_backend, + ) + llm = Engine(server_args=sgl_server_args) + model_input = [ + sample for sample in dataset_inference_conversation_batch.backend_dataset['templated'] + for _ in range(sglang_test_inferencer_args.num_output_sequences) + ] + sglang_outputs = llm.generate( + prompt=model_input, + sampling_params={**sglang_inferencer.sampling_params, "n": 1}, + return_logprob=sglang_test_inferencer_args.return_logprob, + ) + logprobs_lmflow = [parse_logprob(x["meta_info"]["output_token_logprobs"]) for x in res] + logprobs_sglang = [parse_logprob(x["meta_info"]["output_token_logprobs"]) for x in sglang_outputs] + + assert all( + np.allclose(logprobs_lmflow, logprobs_sglang, atol=1e-10) + for logprobs_lmflow, logprobs_sglang in zip(logprobs_lmflow, logprobs_sglang) + ) From cf12a478176ebfca01b4e5ec5951bbe5eb056dce Mon Sep 17 00:00:00 2001 From: Eric Date: Wed, 26 Nov 2025 00:39:18 +0800 Subject: [PATCH 6/6] [ci] lint --- src/lmflow/args.py | 24 ++++++++++++++------- src/lmflow/models/hf_decoder_model.py | 4 ++-- tests/pipeline/test_sglang_infernecer.py | 16 +++++++--------- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/src/lmflow/args.py b/src/lmflow/args.py index 7d935d9c7..162f85b8a 100644 --- a/src/lmflow/args.py +++ b/src/lmflow/args.py @@ -1038,12 +1038,17 @@ class InferencerArguments: ) enable_deterministic_inference: bool = field( default=False, - metadata={"help": 
"Whether to enable deterministic inference. Only supported for SGLang inference engine currently."}, + metadata={ + "help": "Whether to enable deterministic inference. Only supported for SGLang inference engine currently." + }, ) attention_backend: Optional[str] = field( default=None, - metadata={"help": ("The attention backend to use. Only supported for SGLang inference engine currently. " - "Please leave it as None to let SGLang automatically choose if you're not sure.") + metadata={ + "help": ( + "The attention backend to use. Only supported for SGLang inference engine currently. " + "Please leave it as None to let SGLang automatically choose if you're not sure." + ) }, ) @@ -1088,20 +1093,23 @@ def __post_init__(self): "will be removed in a future version. Please use `inference_gpu_memory_utilization` instead." ) self.inference_gpu_memory_utilization = self.vllm_gpu_memory_utilization - + if self.inference_engine != "sglang": if self.return_logprob: logger.warning("`return_logprob` is only supported for SGLang inference engine currently. ") - + if self.inference_engine == "sglang": if self.enable_deterministic_inference: if self.attention_backend is None: self.attention_backend = "fa3" - logger.warning("`enable_deterministic_inference` is enabled, but `attention_backend` is not specified. " - "Using `fa3` as the attention backend by default.") + logger.warning( + "`enable_deterministic_inference` is enabled, but `attention_backend` is not specified. " + "Using `fa3` as the attention backend by default." + ) else: assert self.attention_backend in ["fa3", "flashinfer", "triton"], ( - "Invalid attention backend. Please choose from 'fa3', 'flashinfer', or 'triton'.") + "Invalid attention backend. Please choose from 'fa3', 'flashinfer', or 'triton'." 
+ ) @dataclass diff --git a/src/lmflow/models/hf_decoder_model.py b/src/lmflow/models/hf_decoder_model.py index 221712e12..3c105b4e6 100644 --- a/src/lmflow/models/hf_decoder_model.py +++ b/src/lmflow/models/hf_decoder_model.py @@ -344,8 +344,8 @@ def inference( res = self.__vllm_inference(inputs=inputs, sampling_params=sampling_params) elif inference_engine == "sglang": res = self.__sglang_inference( - inputs=inputs, - sampling_params=sampling_params, + inputs=inputs, + sampling_params=sampling_params, return_logprob=return_logprob, ) else: diff --git a/tests/pipeline/test_sglang_infernecer.py b/tests/pipeline/test_sglang_infernecer.py index 16795f386..5dec650bc 100644 --- a/tests/pipeline/test_sglang_infernecer.py +++ b/tests/pipeline/test_sglang_infernecer.py @@ -1,5 +1,3 @@ -from typing import List, Tuple - import numpy as np import pytest from sglang.srt.entrypoints.engine import Engine @@ -9,8 +7,7 @@ from lmflow.datasets.dataset import Dataset from lmflow.models.hf_decoder_model import HFDecoderModel from lmflow.pipeline.sglang_inferencer import SGLangInferencer - -from tests.datasets.conftest import dataset_inference_conversation_batch +from tests.datasets.conftest import dataset_inference_conversation_batch # noqa: F401 @pytest.fixture @@ -31,11 +28,11 @@ def sglang_test_inferencer_args() -> InferencerArguments: def test_sglang_inferencer( - dataset_inference_conversation_batch: Dataset, + dataset_inference_conversation_batch: Dataset, # noqa: F811 sglang_test_model_args: ModelArguments, sglang_test_inferencer_args: InferencerArguments, ): - def parse_logprob(logprob_list: List[Tuple[float, int, None]]) -> List[float]: + def parse_logprob(logprob_list: list[tuple[float, int, None]]) -> list[float]: return np.array([logprob for logprob, _, _ in logprob_list]) model = HFDecoderModel(model_args=sglang_test_model_args) @@ -65,7 +62,8 @@ def parse_logprob(logprob_list: List[Tuple[float, int, None]]) -> List[float]: ) llm = Engine(server_args=sgl_server_args) 
model_input = [ - sample for sample in dataset_inference_conversation_batch.backend_dataset['templated'] + sample + for sample in dataset_inference_conversation_batch.backend_dataset["templated"] for _ in range(sglang_test_inferencer_args.num_output_sequences) ] sglang_outputs = llm.generate( @@ -75,8 +73,8 @@ def parse_logprob(logprob_list: List[Tuple[float, int, None]]) -> List[float]: ) logprobs_lmflow = [parse_logprob(x["meta_info"]["output_token_logprobs"]) for x in res] logprobs_sglang = [parse_logprob(x["meta_info"]["output_token_logprobs"]) for x in sglang_outputs] - + assert all( - np.allclose(logprobs_lmflow, logprobs_sglang, atol=1e-10) + np.allclose(logprobs_lmflow, logprobs_sglang, atol=1e-10) for logprobs_lmflow, logprobs_sglang in zip(logprobs_lmflow, logprobs_sglang) )