diff --git a/README.md b/README.md
index c68bc605a..dd0d159e4 100644
--- a/README.md
+++ b/README.md
@@ -283,15 +283,12 @@ bash ./scripts/run_chatbot.sh output_models/finetuned_gpt2
 ```
 
 > [!TIP]
-> We recommend using vLLM for faster inference.
+> We recommend using SGLang for faster batch inference.
 >
-> <details><summary>Faster inference using vLLM</summary>
+> <details><summary>Faster inference using SGLang</summary>
 >
 >```bash
->bash ./scripts/run_vllm_inference.sh \
->    --model_name_or_path Qwen/Qwen2-0.5B \
->    --dataset_path data/alpaca/test_conversation \
->    --output_dir data/inference_results \
+>bash ./scripts/run_sglang_inference.sh
 >```
 >
 > </details>
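The README entry runs `scripts/run_sglang_inference.sh` with the defaults baked into that script. To point the new pipeline at a different checkpoint (for example, the `output_models/finetuned_gpt2` model produced earlier in the README), the underlying entry point can be called directly; this is only a minimal sketch, using just the flag names that appear in `scripts/run_sglang_inference.sh` later in this diff:

```bash
# Hypothetical direct invocation with a finetuned checkpoint; flag names are
# taken from scripts/run_sglang_inference.sh in this diff.
python examples/sglang_inference.py \
    --model_name_or_path output_models/finetuned_gpt2 \
    --dataset_path data/alpaca/test_conversation \
    --output_dir output_data/sglang_inference_results \
    --inference_engine sglang \
    --temperature 1.0 \
    --max_new_tokens 2048
```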
diff --git a/examples/sglang_inference.py b/examples/sglang_inference.py
new file mode 100644
index 000000000..5e294455f
--- /dev/null
+++ b/examples/sglang_inference.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python
+# Copyright 2025 Statistics and Machine Learning Research Group. All rights reserved.
+import logging
+import os
+import sys
+
+from transformers import HfArgumentParser
+
+from lmflow.args import (
+    AutoArguments,
+    DatasetArguments,
+    ModelArguments,
+)
+from lmflow.datasets import Dataset
+from lmflow.models.auto_model import AutoModel
+from lmflow.pipeline.auto_pipeline import AutoPipeline
+
+logger = logging.getLogger(__name__)
+
+
+def main():
+    # Parses arguments
+    pipeline_name = "sglang_inferencer"
+    PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
+
+    parser = HfArgumentParser((ModelArguments, DatasetArguments, PipelineArguments))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses()
+
+    dataset = Dataset(data_args)
+    model = AutoModel.get_model(model_args, do_train=False)
+    inferencer = AutoPipeline.get_pipeline(
+        pipeline_name=pipeline_name, model_args=model_args, data_args=data_args, pipeline_args=pipeline_args
+    )
+
+    res = inferencer.inference(
+        model,
+        dataset,
+        release_gpu=True,
+    )
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/scripts/run_sglang_inference.sh b/scripts/run_sglang_inference.sh
new file mode 100644
index 000000000..5bf259121
--- /dev/null
+++ b/scripts/run_sglang_inference.sh
@@ -0,0 +1,14 @@
+python examples/sglang_inference.py \
+    --model_name_or_path Qwen/Qwen3-4B-Instruct-2507 \
+    --dataset_path data/alpaca/test_conversation \
+    --output_dir output_data/sglang_inference_results \
+    --output_file_name results.json \
+    --inference_engine sglang \
+    --inference_gpu_memory_utilization 0.8 \
+    --num_output_sequences 2 \
+    --temperature 1.0 \
+    --max_new_tokens 2048 \
+    --top_p 0.95 \
+    --random_seed 42 \
+    --save_results True \
+    --results_path output_data/sglang_inference_results/results.json
\ No newline at end of file
diff --git a/src/lmflow/pipeline/auto_pipeline.py b/src/lmflow/pipeline/auto_pipeline.py
index cbc7d82e3..e55e74102 100644
--- a/src/lmflow/pipeline/auto_pipeline.py
+++ b/src/lmflow/pipeline/auto_pipeline.py
@@ -4,6 +4,7 @@
 from lmflow.pipeline.evaluator import Evaluator
 from lmflow.pipeline.finetuner import Finetuner
 from lmflow.pipeline.inferencer import Inferencer
+from lmflow.pipeline.sglang_inferencer import SGLangInferencer
 from lmflow.pipeline.rm_inferencer import RewardModelInferencer
 from lmflow.pipeline.rm_tuner import RewardModelTuner
 from lmflow.utils.versioning import is_package_version_at_least, is_ray_available, is_trl_available, is_vllm_available
@@ -12,6 +13,7 @@
     "evaluator": Evaluator,
     "finetuner": Finetuner,
     "inferencer": Inferencer,
+    "sglang_inferencer": SGLangInferencer,
     "rm_inferencer": RewardModelInferencer,
     "rm_tuner": RewardModelTuner,
 }
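Because `examples/sglang_inference.py` builds its arguments with `HfArgumentParser`, it also accepts a single `.json` file in place of command-line flags (the `parse_json_file` branch above). A minimal sketch of that usage follows, assuming the JSON field names mirror the flags used in `scripts/run_sglang_inference.sh`; the exact accepted fields depend on the SGLang pipeline's argument dataclass, which is not part of this diff:

```bash
# Hypothetical JSON-config invocation; field names are assumed to match the
# CLI flags in scripts/run_sglang_inference.sh.
cat > sglang_inference_args.json <<'EOF'
{
    "model_name_or_path": "Qwen/Qwen3-4B-Instruct-2507",
    "dataset_path": "data/alpaca/test_conversation",
    "output_dir": "output_data/sglang_inference_results",
    "inference_engine": "sglang",
    "temperature": 1.0,
    "max_new_tokens": 2048,
    "top_p": 0.95
}
EOF
python examples/sglang_inference.py sglang_inference_args.json
```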