219 changes: 116 additions & 103 deletions fastdeploy/worker/xpu_model_runner.py
@@ -18,6 +18,7 @@
import queue
import random
import time
from contextlib import contextmanager
from threading import Thread
from typing import List, Optional

@@ -67,6 +68,21 @@
logger = get_logger("xpu_model_runner", "xpu_model_runner.log")


@contextmanager
def kv_signal_sender_context_manager(pd_disaggregation_mode):
sender = None
try:
sender = (
create_kv_signal_sender()
if pd_disaggregation_mode == "per_chunk" or pd_disaggregation_mode == "per_query"
else None
)
yield sender
finally:
if sender is not None:
destroy_kv_signal_sender(sender)
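
# Usage sketch (illustrative, not part of this diff): the try/finally above
# guarantees the sender is torn down even if the wrapped block raises or
# returns early, e.g.:
#
#     with kv_signal_sender_context_manager("per_chunk") as sender:
#         ...  # forward pass; `sender` is None unless the mode is
#         ...  # "per_chunk" or "per_query"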


class XPUModelRunner(ModelRunnerBase):
""" """

@@ -1359,115 +1375,112 @@ class at the server level, which is too granular for ModelRunner.
"""
# 0. set debug level
# self._set_debug_level(0x1, model_forward_batch, is_dummy_run)
if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query":
self.kv_signal_sender = create_kv_signal_sender()
# 1. Prepare inputs of model and decoder.
self._prepare_inputs(is_dummy_run=is_dummy_run)
# NOTE(wufeisheng): If `not_need_stop` is False, it means the current worker is in an idle state.
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runners, the current runner is required to execute part of the model.
if not self.not_need_stop() and not is_dummy_run:
self._execute_empty_input(self.forward_meta)
return None

# 2. Padding inputs for cuda graph

# 3. Execute model
if self.enable_mm:
model_output = self.model(
self.share_inputs["ids_remove_padding"], self.share_inputs["image_features"], self.forward_meta
)
else:
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta,
)
with kv_signal_sender_context_manager(self.pd_disaggregation_mode) as sender:
self.kv_signal_sender = sender
# 1. Prepare inputs of model and decoder.
self._prepare_inputs(is_dummy_run=is_dummy_run)
# NOTE(wufeisheng): If `not_need_stop` is False, it means the current worker is in an idle state.
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runners, the current runner is required to execute part of the model.
if not self.not_need_stop() and not is_dummy_run:
self._execute_empty_input(self.forward_meta)
return None

# 2. Padding inputs for cuda graph

# 3. Execute model
if self.enable_mm:
model_output = self.model(
self.share_inputs["ids_remove_padding"], self.share_inputs["image_features"], self.forward_meta
)
else:
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta,
)

hidden_states = xpu_process_output(
model_output, self.share_inputs["cum_offsets"], self.forward_meta, self.share_inputs
)
# 4. Compute logits, Sample
logits = self.model.compute_logits(hidden_states)
sampler_output = None
if not self.speculative_decoding:
sampler_output = self.sampler(logits, self.sampling_metadata)
else:
self.sampler(
logits,
self.sampling_metadata,
self.model_config.max_model_len,
self.share_inputs,
hidden_states = xpu_process_output(
model_output, self.share_inputs["cum_offsets"], self.forward_meta, self.share_inputs
)
# 4. Compute logits, Sample
logits = self.model.compute_logits(hidden_states)
sampler_output = None
if not self.speculative_decoding:
sampler_output = self.sampler(logits, self.sampling_metadata)
else:
self.sampler(
logits,
self.sampling_metadata,
self.model_config.max_model_len,
self.share_inputs,
)

# 5. Speculative decode

# 6. Post Process
prompt_logprobs_list = None
if not self.speculative_decoding:
prompt_logprobs_list = self._get_prompt_logprobs_list(model_output)

model_output_data = ModelOutputData(
next_tokens=self.share_inputs["next_tokens"],
stop_flags=self.share_inputs["stop_flags"],
step_idx=self.share_inputs["step_idx"],
max_dec_len=self.share_inputs["max_dec_len"],
pre_ids=self.share_inputs["pre_ids"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
eos_token_id=self.share_inputs["eos_token_id"],
not_need_stop=self.share_inputs["not_need_stop"],
input_ids=self.share_inputs["input_ids"],
stop_nums=self.share_inputs["stop_nums"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
is_block_step=self.share_inputs["is_block_step"],
# speculative decoding
full_hidden_states=model_output if self.speculative_decoding else None,
msg_queue_id=self.parallel_config.msg_queue_id,
mp_rank=self.local_rank,
use_ep=self.parallel_config.use_ep,
draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
actual_draft_token_num=(
self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None
),
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
min_tokens=self.share_inputs["min_dec_len"],
prompt_logprobs_list=prompt_logprobs_list,
)
if self.speculative_decoding:
# base model post process
xpu_post_process_specualate(model_output_data, False, is_dummy_run)
else:
xpu_post_process_normal(
sampler_output=sampler_output,
model_output=model_output_data,
share_inputs=self.share_inputs,
block_size=self.cache_config.block_size,
skip_save_output=is_dummy_run,
async_output_queue=self.async_output_queue,
think_end_id=self.model_config.think_end_id,
line_break_id=self.model_config.line_break_id,
# 5. Speculative decode

# 6. Post Process
prompt_logprobs_list = None
if not self.speculative_decoding:
prompt_logprobs_list = self._get_prompt_logprobs_list(model_output)

model_output_data = ModelOutputData(
next_tokens=self.share_inputs["next_tokens"],
stop_flags=self.share_inputs["stop_flags"],
step_idx=self.share_inputs["step_idx"],
max_dec_len=self.share_inputs["max_dec_len"],
pre_ids=self.share_inputs["pre_ids"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
eos_token_id=self.share_inputs["eos_token_id"],
not_need_stop=self.share_inputs["not_need_stop"],
input_ids=self.share_inputs["input_ids"],
stop_nums=self.share_inputs["stop_nums"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
is_block_step=self.share_inputs["is_block_step"],
# speculative decoding
full_hidden_states=model_output if self.speculative_decoding else None,
msg_queue_id=self.parallel_config.msg_queue_id,
mp_rank=self.local_rank,
use_ep=self.parallel_config.use_ep,
draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
actual_draft_token_num=(
self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None
),
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
min_tokens=self.share_inputs["min_dec_len"],
prompt_logprobs_list=prompt_logprobs_list,
)
if self.speculative_decoding:
# base model post process
xpu_post_process_specualate(model_output_data, False, is_dummy_run)
else:
xpu_post_process_normal(
sampler_output=sampler_output,
model_output=model_output_data,
share_inputs=self.share_inputs,
block_size=self.cache_config.block_size,
skip_save_output=is_dummy_run,
async_output_queue=self.async_output_queue,
think_end_id=self.model_config.think_end_id,
line_break_id=self.model_config.line_break_id,
)

# draft model propose
if self.speculative_method == "mtp":
self.proposer.run(full_hidden_states=model_output)

# 7. Update 'infer_seed' and step_xpu()
self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
step_xpu(
self.share_inputs,
self.cache_config.block_size,
self.cache_config.enc_dec_block_num,
self.speculative_decoding,
self.speculative_config.num_speculative_tokens,
)
# draft model propose
if self.speculative_method == "mtp":
self.proposer.run(full_hidden_states=model_output)

if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query":
destroy_kv_signal_sender(self.kv_signal_sender)
# 7. Update 'infer_seed' and step_xpu()
self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
step_xpu(
self.share_inputs,
self.cache_config.block_size,
self.cache_config.enc_dec_block_num,
self.speculative_decoding,
self.speculative_config.num_speculative_tokens,
)
return None

def _execute_empty_input(self, forward_meta) -> None:
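The behavioral fix in this file is worth spelling out: in the old code, create_kv_signal_sender() ran at the top of the method, but the idle-worker path (`if not self.not_need_stop() and not is_dummy_run`) returned before destroy_kv_signal_sender() was reached, leaking the sender. Wrapping the whole body in the context manager makes teardown unconditional. A minimal, self-contained sketch of the pattern (resource_ctx and handle are illustrative stand-ins, not the FastDeploy API):

from contextlib import contextmanager

@contextmanager
def resource_ctx(enabled):
    # Stand-in for create_kv_signal_sender(); None when the mode doesn't need one.
    handle = object() if enabled else None
    try:
        yield handle
    finally:
        if handle is not None:
            print("destroyed")  # stand-in for destroy_kv_signal_sender(handle)

def run():
    with resource_ctx(True) as handle:
        return None  # early exit, analogous to the idle-worker return

run()  # prints "destroyed" despite the early return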
26 changes: 26 additions & 0 deletions tests/xpu_ci/conftest.py
@@ -75,6 +75,8 @@ def safe_kill_cmd(cmd):
commands = [
"ps -efww | grep -E 'cache_transfer_manager.py' | grep -v grep | awk '{print $2}' | xargs echo",
"ps -efww | grep -E 'api_server' | grep -v grep | awk '{print $2}' | xargs echo",
"ps -efww | grep -E 'multiprocessing' | grep -v grep | awk '{print $2}' | xargs echo",
"ps -efww | grep -E 'fastdeploy' | grep -v grep | awk '{print $2}' | xargs echo",
f"ps -efww | grep -E '{port_num}' | grep -v grep | awk '{{print $2}}' | xargs echo",
f"lsof -t -i :{port_num} | xargs echo",
]
@@ -434,3 +436,27 @@ def restore_pd_env(original_values):
else:
os.environ[key] = original_values[key]
print(f"恢复环境变量: {key}={original_values[key]}")


def setup_pd_ep_env():
"""
Set environment variables for PD disaggregation + EP

Returns:
dict: Original environment variable values, used for later restoration
"""
original_values_pd = setup_pd_env()
original_values_ep = setup_ep_env()
original_values = {**original_values_pd, **original_values_ep}
return original_values


def restore_pd_ep_env(original_values):
"""
Restore PD disaggregation + EP environment variables

Args:
original_values: the original environment variable values returned by setup_pd_ep_env()
"""
restore_env(original_values)
restore_pd_env(original_values)
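
These helpers pair naturally with a fixture so restoration runs even when a test fails. A hedged sketch (the fixture name pd_ep_env is an assumption, not part of this PR):

import pytest

@pytest.fixture
def pd_ep_env():
    # Save originals, apply PD-disaggregation + EP settings, always restore.
    original_values = setup_pd_ep_env()
    try:
        yield
    finally:
        restore_pd_ep_env(original_values)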