9 changes: 3 additions & 6 deletions README.md
@@ -283,15 +283,12 @@ bash ./scripts/run_chatbot.sh output_models/finetuned_gpt2
 ```
 
 > [!TIP]
-> We recommend using vLLM for faster inference.
+> We recommend using SGLang for faster batch inference.
 >
-> <details><summary>Faster inference using vLLM</summary>
+> <details><summary>Faster inference using SGLang</summary>
 >
 >```bash
->bash ./scripts/run_vllm_inference.sh \
->    --model_name_or_path Qwen/Qwen2-0.5B \
->    --dataset_path data/alpaca/test_conversation \
->    --output_dir data/inference_results \
+>bash ./scripts/run_sglang_inference.sh
 >```
 >
 > </details>
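For context, the faster batch inference recommended here comes from SGLang's offline engine, which the new pipeline below presumably wraps. A minimal sketch of that API, assuming `pip install sglang`; the keyword names (`model_path`, `mem_fraction_static`) and the dict-based sampling parameters follow SGLang's offline `Engine` interface, and the model and prompts are placeholders:

```python
# Minimal sketch of SGLang offline batch generation (not LMFlow's implementation).
import sglang as sgl

llm = sgl.Engine(
    model_path="Qwen/Qwen3-4B-Instruct-2507",  # placeholder model
    mem_fraction_static=0.8,  # GPU memory fraction, cf. --inference_gpu_memory_utilization below
)
prompts = ["What is instruction tuning?", "Explain top-p sampling in one sentence."]
sampling_params = {"temperature": 1.0, "top_p": 0.95, "max_new_tokens": 256}
outputs = llm.generate(prompts, sampling_params)  # one batched call over all prompts
for prompt, out in zip(prompts, outputs):
    print(prompt, "->", out["text"])
llm.shutdown()  # free GPU memory, cf. release_gpu=True in the example script below
```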
48 changes: 48 additions & 0 deletions examples/sglang_inference.py
@@ -0,0 +1,48 @@
#!/usr/bin/env python
# Copyright 2025 Statistics and Machine Learning Research Group. All rights reserved.
import logging
import os
import sys

from transformers import HfArgumentParser

from lmflow.args import (
    AutoArguments,
    DatasetArguments,
    ModelArguments,
)
from lmflow.datasets import Dataset
from lmflow.models.auto_model import AutoModel
from lmflow.pipeline.auto_pipeline import AutoPipeline

logger = logging.getLogger(__name__)


def main():
    # Parses arguments
    pipeline_name = "sglang_inferencer"
    PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)

    parser = HfArgumentParser((ModelArguments, DatasetArguments, PipelineArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses()

    dataset = Dataset(data_args)
    model = AutoModel.get_model(model_args, do_train=False)
    inferencer = AutoPipeline.get_pipeline(
        pipeline_name=pipeline_name, model_args=model_args, data_args=data_args, pipeline_args=pipeline_args
    )

    res = inferencer.inference(
        model,
        dataset,
        release_gpu=True,
    )


if __name__ == "__main__":
    main()
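Note the `.json` branch above: because the script delegates to `HfArgumentParser.parse_json_file`, all flags can alternatively be supplied as a single flat JSON file whose keys match the dataclass field names. A hedged sketch of that invocation path (the config filename and the minimal key set are hypothetical; omitted fields fall back to their dataclass defaults):

```bash
# Write a flat JSON config (keys = dataclass field names), then pass its path
# as the script's only argument.
cat > configs/sglang_inference.json <<'EOF'
{
    "model_name_or_path": "Qwen/Qwen3-4B-Instruct-2507",
    "dataset_path": "data/alpaca/test_conversation",
    "output_dir": "output_data/sglang_inference_results"
}
EOF
python examples/sglang_inference.py configs/sglang_inference.json
```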
14 changes: 14 additions & 0 deletions scripts/run_sglang_inference.sh
@@ -0,0 +1,14 @@
python examples/sglang_inference.py \
    --model_name_or_path Qwen/Qwen3-4B-Instruct-2507 \
    --dataset_path data/alpaca/test_conversation \
    --output_dir output_data/sglang_inference_results \
    --output_file_name results.json \
    --inference_engine sglang \
    --inference_gpu_memory_utilization 0.8 \
    --num_output_sequences 2 \
    --temperature 1.0 \
    --max_new_tokens 2048 \
    --top_p 0.95 \
    --random_seed 42 \
    --save_results True \
    --results_path output_data/sglang_inference_results/results.json
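The sampling flags above presumably translate into an SGLang sampling-parameters dict inside `SGLangInferencer`. The mapping sketched here is an assumption about the inferencer's internals, though the SGLang-side key names themselves (including `n` for multiple samples per prompt) are standard:

```python
# Assumed flag -> SGLang sampling_params mapping (a sketch, not LMFlow's code).
sampling_params = {
    "temperature": 1.0,      # --temperature
    "top_p": 0.95,           # --top_p
    "max_new_tokens": 2048,  # --max_new_tokens
    "n": 2,                  # --num_output_sequences: two samples per prompt
}
```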
2 changes: 2 additions & 0 deletions src/lmflow/pipeline/auto_pipeline.py
@@ -4,6 +4,7 @@
 from lmflow.pipeline.evaluator import Evaluator
 from lmflow.pipeline.finetuner import Finetuner
 from lmflow.pipeline.inferencer import Inferencer
+from lmflow.pipeline.sglang_inferencer import SGLangInferencer
 from lmflow.pipeline.rm_inferencer import RewardModelInferencer
 from lmflow.pipeline.rm_tuner import RewardModelTuner
 from lmflow.utils.versioning import is_package_version_at_least, is_ray_available, is_trl_available, is_vllm_available
@@ -12,6 +13,7 @@
     "evaluator": Evaluator,
     "finetuner": Finetuner,
     "inferencer": Inferencer,
+    "sglang_inferencer": SGLangInferencer,
     "rm_inferencer": RewardModelInferencer,
     "rm_tuner": RewardModelTuner,
 }
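These two added lines register the new pipeline under the `"sglang_inferencer"` key, which is exactly the `pipeline_name` that `examples/sglang_inference.py` requests. A simplified illustration of the dispatch this mapping enables, using the hypothetical name `PIPELINE_MAPPING` for the dict whose entries appear above (a sketch, not LMFlow's actual `get_pipeline`):

```python
# Simplified name-based dispatch: look up the pipeline class, then build it.
class AutoPipeline:
    @classmethod
    def get_pipeline(cls, pipeline_name, model_args, data_args, pipeline_args, **kwargs):
        if pipeline_name not in PIPELINE_MAPPING:
            raise NotImplementedError(f'Pipeline "{pipeline_name}" is not supported')
        pipeline_class = PIPELINE_MAPPING[pipeline_name]
        return pipeline_class(
            model_args=model_args,
            data_args=data_args,
            pipeline_args=pipeline_args,
            **kwargs,
        )
```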