Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions xllm/core/distributed_runtime/llm_master.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,19 +439,16 @@ std::shared_ptr<Request> LLMMaster::generate_request(
std::optional<Call*> call,
OutputCallback callback) {
Timer timer;
std::optional<std::string> prompt;
if (sp.has_tools()) {
prompt = chat_template_->apply(messages, sp.tools, sp.chat_template_kwargs);
} else {
prompt = chat_template_->apply(messages, sp.chat_template_kwargs);
}

std::optional<std::string> prompt;
prompt = chat_template_->apply(messages, sp.tools, sp.chat_template_kwargs);
if (!prompt.has_value()) {
CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT,
"Failed to construct prompt from messages");
LOG(ERROR) << "Failed to construct prompt from messages";
return nullptr;
}

COUNTER_ADD(chat_template_latency_seconds, timer.elapsed_seconds());

return generate_request(
Expand Down
8 changes: 3 additions & 5 deletions xllm/core/distributed_runtime/rec_master.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,19 +524,17 @@ void RecMaster::handle_request(
}

Timer timer;

std::optional<std::string> prompt;
if (sp.has_tools()) {
prompt = chat_template_->apply(messages, sp.tools, sp.chat_template_kwargs);
} else {
prompt = chat_template_->apply(messages, sp.chat_template_kwargs);
}
prompt = chat_template_->apply(messages, sp.tools, sp.chat_template_kwargs);

if (!prompt.has_value()) {
CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT,
"Failed to construct prompt from messages");
LOG(ERROR) << "Failed to construct prompt from messages";
return;
}

COUNTER_ADD(chat_template_latency_seconds, timer.elapsed_seconds());

schedule_request(std::move(sp),
Expand Down
11 changes: 5 additions & 6 deletions xllm/core/framework/chat_template/jinja_chat_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,12 @@ class JinjaChatTemplate {

std::optional<std::string> apply(const ChatMessages& messages) const;

std::optional<std::string> apply(
const ChatMessages& messages,
const nlohmann::ordered_json& chat_template_kwargs) const;

std::optional<std::string> apply(
const ChatMessages& messages,
const std::vector<xllm::JsonTool>& json_tools,
const nlohmann::ordered_json& chat_template_kwargs) const;

// expose this function for testing
protected:
// apply the template to the values in the json object
std::optional<std::string> apply(nlohmann::ordered_json& messages) const;

Expand All @@ -53,7 +49,10 @@ class JinjaChatTemplate {
const nlohmann::ordered_json& tools,
const nlohmann::ordered_json& chat_template_kwargs) const;

private:
std::optional<std::string> apply(
const ChatMessages& messages,
const nlohmann::ordered_json& chat_template_kwargs) const;

nlohmann::ordered_json get_mm_content(const MMContentVec& vec) const;

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ limitations under the License.

namespace xllm {

// Test-only shim that widens access to JinjaChatTemplate's protected
// apply() overloads so unit tests can call them directly.
// NOTE(review): assumes JinjaChatTemplate(const TokenizerArgs&) is the
// intended base constructor — confirm against jinja_chat_template.h.
class TestableJinjaChatTemplate : public JinjaChatTemplate {
 public:
  // Forward construction straight to the base; no extra state is added.
  TestableJinjaChatTemplate(const TokenizerArgs& args)
      : JinjaChatTemplate(args) {}

  // Re-export every base apply() overload (public and protected) as public
  // so tests can exercise the json-based overloads moved behind `protected`.
  using JinjaChatTemplate::apply;
};

TEST(JinjaChatTemplate, OpenChatModel) {
// clang-format off
const std::string template_str =
Expand Down Expand Up @@ -46,7 +54,7 @@ TEST(JinjaChatTemplate, OpenChatModel) {
args.chat_template(template_str);
args.bos_token("");
args.eos_token("<|end_of_turn|>");
JinjaChatTemplate template_(args);
TestableJinjaChatTemplate template_(args);
auto result = template_.apply(messages);
ASSERT_TRUE(result.has_value());

Expand Down
2 changes: 1 addition & 1 deletion xllm/core/framework/request/request_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ struct RequestParams {

// JSON-based tools (replacing proto_tools)
std::vector<xllm::JsonTool> tools;

std::string tool_choice = "auto";
bool has_tools() const { return !tools.empty(); }

bool offline = false;

Expand Down
Loading