Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions examples/generate.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
# python examples/generate.py --model='/path/models/Qwen2-7B-Instruct' --devices='npu:0'
# python generate.py --model='/path/models/Qwen2-7B-Instruct' --devices='npu:0,npu:1'

from xllm import ArgumentParser, LLM, RequestParams
from xllm import ArgumentParser, LLM, SamplingParams

# Create an LLM.
parser = ArgumentParser()
llm = LLM(**vars(parser.parse_args()))

# Create a request params, include sampling params
request_params = RequestParams()
request_params.temperature = 0.8
request_params.top_p = 0.95
request_params.max_tokens = 10
# Create sampling params.
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=10,
)

# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
Expand All @@ -22,7 +23,7 @@
"The future of AI is",
]

outputs = llm.generate(prompts, request_params, True)
outputs = llm.generate(prompts, sampling_params=sampling_params)

# Print the outputs.
for i, output in enumerate(outputs):
Expand All @@ -31,4 +32,3 @@
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

llm.finish()

29 changes: 29 additions & 0 deletions examples/generate_beam_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# python examples/generate_beam_search.py --model='/path/models/Qwen2-7B-Instruct' --devices='npu:0'
# python generate_beam_search.py --model='/path/models/Qwen2-7B-Instruct' --devices='npu:0,npu:1'

from xllm import ArgumentParser, BeamSearchParams, LLM

# Create an LLM.
parser = ArgumentParser()
llm = LLM(**vars(parser.parse_args()))

beam_search_params = BeamSearchParams(
beam_width=2,
max_tokens=20,
)

outputs = llm.beam_search(
[
{"prompt": "Hello, my name is "},
{"prompt": "The president of the United States is "},
{"prompt": "The capital of France is "},
{"prompt": "The future of AI is "}
],
params=beam_search_params,
)

for output in outputs:
generated_text = output.sequences[0].text
print(f"Generated text: {generated_text!r}")

llm.finish()
21 changes: 10 additions & 11 deletions examples/generate_embedding.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
# python examples/generate_embedding.py --model='/path/models/Qwen3-8B' --devices='npu:0'
# python examples/generate_embedding.py --model='/path/models/Qwen3-8B' --devices='npu:0' --runner pooling
# python generate_embedding.py --model='/path/models/Qwen3-8B' --devices='npu:0,npu:1'

from xllm import ArgumentParser, Embedding, RequestParams
from xllm import ArgumentParser, LLM, PoolingParams

# Create an EmbeddingLM.
# Create an embedding LLM.
parser = ArgumentParser()
emb = Embedding(**vars(parser.parse_args()))
args = parser.parse_args()
llm = LLM(**vars(args))

# Create a request params, include sampling params
request_params = RequestParams()
request_params.is_embeddings = True
request_params.max_tokens = 1
# Create pooling params.
pooling_params = PoolingParams()

inputs = [
"Hello, my name is",
Expand All @@ -19,13 +18,13 @@
"The future of AI is",
]

outputs = emb.embedding(inputs, request_params, True)
outputs = llm.embed(inputs, pooling_params=pooling_params)

# Print the outputs.
for i, output in enumerate(outputs):
input_str = output.prompt
generated_embedding = output.outputs[0].embeddings
generated_embedding = output.outputs.embedding
print(f"Input: {input_str!r}, Generated embedding: {generated_embedding!r}")

emb.finish()
llm.finish()

102 changes: 58 additions & 44 deletions examples/generate_vlm.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,75 @@
# python generate_vlm.py --model /path/to/Qwen2.5-VL-7B-Instruct/ --disable_prefix_cache --disable_chunked_prefill --max_seqs_per_batch 4 --devices='npu:0' --enable_shm

from xllm import ArgumentParser, SamplingParams
from xllm import LLM
# from xllm import VLM
import base64
import os
import signal

from xllm import ArgumentParser, VLM, RequestParams, MMType, MMData
def encode_image_from_file(file_path: str) -> str:
if not os.path.exists(file_path):
raise FileNotFoundError(f"not found image: {file_path}")
with open(file_path, "rb") as image_file:
result = base64.b64encode(image_file.read()).decode("utf-8")
return result

from PIL import Image
from transformers import AutoImageProcessor

# Create an VLM.
parser = ArgumentParser()
args = parser.parse_args()

vlm = VLM(**vars(args))
processor = AutoImageProcessor.from_pretrained(args.model, trust_remote_code=True)

questions = ["简单介绍下图片"]
prompts = [
(
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
for question in questions
# vlm = VLM(**vars(args))
vlm = LLM(**vars(args))


image_1 = "./images/3.jpg"
image_2 = "./images/4.jpg"

# image_base64_1 = encode_image_from_file(image_1)
# image_base64_2 = encode_image_from_file(image_2)

# image_1 = f"data:image/jpeg;base64,{image_base64_1}"
# image_2 = f"data:image/jpeg;base64,{image_base64_2}"

requests = [
{
"prompt": (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
"<|vision_start|><|image_pad|><|vision_end|>"
"请描述这张图片。<|im_end|>\n"
Comment thread
RobbieLeung marked this conversation as resolved.
"<|im_start|>assistant\n"
),
"multi_modal_data": {
"image": image_1,
},
},
{
"prompt": (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
"<|vision_start|><|image_pad|><|vision_end|>"
"<|vision_start|><|image_pad|><|vision_end|>"
"请对比这两张图片的主要区别。<|im_end|>\n"
"<|im_start|>assistant\n"
),
"multi_modal_data": {
"image": [image_1, image_2],
},
},
]

paths = ["00307664d4ce393b.png"]
images = []
for path in paths:
images.append(Image.open(path).convert("RGB"))

multi_modal_datas = []
for idx in range(len(images)):
print(f"Processing image: {paths[idx]}")
image = images[idx]

data = processor.preprocess([image], return_tensors="pt").data
mm_data = {
"pixel_values": data['pixel_values'],
"image_grid_thw": data['image_grid_thw'],
}
multi_modal_datas.append(MMData(MMType.IMAGE, mm_data))
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=50,
)


# Create a request params, include sampling params
request_params = RequestParams()
request_params.temperature = 0
request_params.max_tokens = 1024

outputs = vlm.generate(prompts, multi_modal_datas, request_params, True)
outputs = vlm.generate(
requests,
sampling_params=sampling_params
)

for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

vlm.finish()


3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,8 @@ def parse_arguments() -> dict[str, Any]:
py_modules=["xllm/launch_xllm", "xllm/__init__",
"xllm/pybind/llm", "xllm/pybind/vlm",
"xllm/pybind/embedding", "xllm/pybind/util",
"xllm/pybind/args"],
"xllm/pybind/args", "xllm/pybind/params",
"xllm/pybind/errors", "xllm/pybind/mm_utils"],
entry_points={
'console_scripts': [
'xllm = xllm.launch_xllm:launch_xllm'
Expand Down
11 changes: 10 additions & 1 deletion xllm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,19 @@ def _find_export_so_path() -> str:

from xllm.pybind.embedding import Embedding
from xllm.pybind.llm import LLM
from xllm.pybind.vlm import VLM
try:
from xllm.pybind.vlm import VLM
except Exception:
VLM = None
from xllm.pybind.args import ArgumentParser
from xllm.pybind.params import SamplingParams, BeamSearchParams, PoolingParams
from xllm_export import (
LLMMaster,
VLMMaster,
Options,
RequestParams,
RequestOutput,
Usage,
SequenceOutput,
Status,
StatusCode,
Expand All @@ -69,8 +74,12 @@ def _find_export_so_path() -> str:
"VLM",
"VLMMaster",
"Options",
"SamplingParams",
"BeamSearchParams",
"PoolingParams",
"RequestParams",
"RequestOutput",
"Usage",
"SequenceOutput",
"Status",
"StatusCode",
Expand Down
2 changes: 1 addition & 1 deletion xllm/core/distributed_runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
include(cc_library)

if(USE_NPU)
if(USE_NPU OR USE_CUDA)
include_directories(
${CMAKE_SOURCE_DIR}/third_party/spdlog/include
)
Expand Down
8 changes: 8 additions & 0 deletions xllm/core/distributed_runtime/llm_master.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,14 @@ std::shared_ptr<Request> LLMMaster::generate_request(
// enable logprobs for best_of to generate sequence logprob
sampling_param.logprobs = true;
}
if (sampling_param.beam_width > 1) {
// beam search requires logprobs, and needs at least one top_logprob
// candidate for beam expansion.
sampling_param.logprobs = true;
if (sampling_param.top_logprobs == 0) {
sampling_param.top_logprobs = 1;
}
}
// sampling_param.do_sample = sp.do_sample;

SchedulerParam scheduler_param;
Expand Down
4 changes: 3 additions & 1 deletion xllm/core/distributed_runtime/remote_worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ RemoteWorker::RemoteWorker(int32_t global_rank,
const torch::Device& d,
std::unique_ptr<CommChannel> channel)
: global_rank_(global_rank), device_(d), channel_(std::move(channel)) {
wait_for_server_ready(server_address);
CHECK(wait_for_server_ready(server_address))
<< "Failed to wait for remote worker server ready: " << server_address
<< ", global_rank: " << global_rank_;
}

bool RemoteWorker::wait_for_server_ready(const std::string& server_address) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,25 @@ namespace xllm {

bool xllm::SpawnWorkerServer::g_running_ = true;

namespace {
std::string get_backend_from_worker_type(const std::string& worker_type) {
if (worker_type == "LLM" || worker_type == "ELM") {
return "llm";
}
if (worker_type == "VLM" || worker_type == "EVLM" ||
worker_type == "MMEVLM") {
return "vlm";
}
if (worker_type == "REC") {
return "rec";
}
if (worker_type == "DIT") {
return "dit";
}
return "";
}
} // namespace

SpawnWorkerServer::SpawnWorkerServer(const std::string& master_node_addr,
int local_rank,
int global_rank,
Expand All @@ -45,10 +64,15 @@ SpawnWorkerServer::SpawnWorkerServer(const std::string& master_node_addr,
bool is_local,
bool enable_prefill_sp,
const std::string& task_type,
const std::string& worker_type) {
const std::string& worker_type,
const std::string& communication_backend) {
// TODO: pass whole xllm::runtime::Options here from main process.
xllm::runtime::Options runner_options;
const std::string backend = get_backend_from_worker_type(worker_type);
CHECK(!backend.empty()) << "Unsupported worker_type for backend mapping: "
<< worker_type;
runner_options.block_size(block_size)
.backend(backend)
.num_decoding_tokens(num_decoding_tokens)
.enable_prefill_sp(enable_prefill_sp)
.enable_schedule_overlap(false)
Expand All @@ -63,26 +87,31 @@ SpawnWorkerServer::SpawnWorkerServer(const std::string& master_node_addr,
FLAGS_enable_prefill_sp = enable_prefill_sp;
FLAGS_master_node_addr = master_node_addr;
FLAGS_block_size = block_size;
FLAGS_communication_backend = communication_backend;

std::atomic<bool> done(false);
#if defined(USE_NPU)
xllm::Device device("npu:" + std::to_string(device_idx));
const std::string device_type = xllm::Device::type_str();
const std::string device_str = device_type + ":" + std::to_string(device_idx);
xllm::Device device{torch::Device(device_str)};
device.set_device();

#if defined(USE_NPU)
device.init_device_context();
FLAGS_enable_atb_comm_multiprocess = true;
#endif

ParallelArgs parallel_args(global_rank, world_size, 1, nullptr, 1);
WorkerServer worker_server(local_rank,
master_node_addr,
done,
parallel_args,
device,
runner_options,
worker_type,
false);
worker_server_ = std::make_unique<WorkerServer>(local_rank,
master_node_addr,
done_,
parallel_args,
device,
runner_options,
worker_type,
false);
}

SpawnWorkerServer::~SpawnWorkerServer() = default;

void SpawnWorkerServer::handle_signal(int signum) { g_running_ = false; }

void SpawnWorkerServer::run() {
Expand Down
Loading
Loading