From 22a8fc2dfe95bfbf4259a24b41c2c5607491b018 Mon Sep 17 00:00:00 2001 From: lixiang28 Date: Thu, 8 Jan 2026 17:33:09 +0800 Subject: [PATCH 1/2] feat: add dynamic prompt support for llm/mllm --- interface/ten_ai_base/llm.py | 35 +++++++++++++++++++++++++ interface/ten_ai_base/mllm.py | 48 +++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/interface/ten_ai_base/llm.py b/interface/ten_ai_base/llm.py index 1b6534a..cdbed71 100644 --- a/interface/ten_ai_base/llm.py +++ b/interface/ten_ai_base/llm.py @@ -26,6 +26,8 @@ from .helper import AsyncQueue import json +DATA_LLM_IN_SET_PROMPT = "llm_client_set_prompt" + class AsyncLLMBaseExtension(AsyncExtension, ABC): """ @@ -50,9 +52,12 @@ def __init__(self, name: str): self.hit_default_cmd = False self.loop_task = None self.loop = None + self.dynamic_prompt: str | None = None # Dynamic prompt set by prompt_template extension + self.ten_env: AsyncTenEnv = None # type: ignore async def on_init(self, async_ten_env: AsyncTenEnv) -> None: await super().on_init(async_ten_env) + self.ten_env = async_ten_env async def on_start(self, async_ten_env: AsyncTenEnv) -> None: await super().on_start(async_ten_env) @@ -70,6 +75,36 @@ async def on_stop(self, async_ten_env: AsyncTenEnv) -> None: async def on_deinit(self, async_ten_env: AsyncTenEnv) -> None: await super().on_deinit(async_ten_env) + async def on_data(self, async_ten_env: AsyncTenEnv, data: Data) -> None: + """Handle incoming data messages.""" + data_name = data.get_name() + async_ten_env.log_debug(f"on_data name: {data_name}") + if data_name == DATA_LLM_IN_SET_PROMPT: + prompt_json, _ = data.get_property_to_json(None) + prompt_data = json.loads(prompt_json) + new_prompt = prompt_data.get("prompt", "") + self.dynamic_prompt = new_prompt + async_ten_env.log_info(f"Dynamic prompt set: {new_prompt[:100]}...") + await self.on_prompt_update(new_prompt) + + async def on_prompt_update(self, prompt: str) -> None: + """ + Called when a dynamic prompt is set by the prompt_template extension. + Subclasses can override this method to apply the new prompt. + Default implementation does nothing - subclass should update conversation context if needed. + """ + pass + + def get_prompt(self, config_prompt: str) -> str: + """ + Get the effective prompt to use. + Returns dynamic_prompt if set by prompt_template extension, + otherwise returns the config_prompt from property.json. + """ + if self.dynamic_prompt is not None: + return self.dynamic_prompt + return config_prompt + async def on_cmd(self, async_ten_env: AsyncTenEnv, cmd: Cmd) -> None: """ handle default commands diff --git a/interface/ten_ai_base/mllm.py b/interface/ten_ai_base/mllm.py index f578b20..72d4e31 100644 --- a/interface/ten_ai_base/mllm.py +++ b/interface/ten_ai_base/mllm.py @@ -7,6 +7,7 @@ import traceback from typing import final import uuid +import re from .struct import MLLMClientCreateResponse, MLLMClientFunctionCallOutput, MLLMClientMessageItem, MLLMClientRegisterTool, MLLMClientSendMessageItem, MLLMClientSetMessageContext, MLLMServerFunctionCall, MLLMServerInputTranscript, MLLMServerOutputTranscript, MLLMServerSessionReady, MLLMServerInterrupt @@ -44,6 +45,7 @@ DATA_MLLM_IN_CREATE_RESPONSE = "mllm_client_request_create_response" DATA_MLLM_IN_REGISTER_TOOL = "mllm_client_request_register_tool" DATA_MLLM_IN_FUNCTION_CALL_OUTPUT = "mllm_client_function_call_output" +DATA_MLLM_IN_SET_PROMPT = "mllm_client_set_prompt" class AsyncMLLMBaseExtension(AsyncExtension): @@ -60,6 +62,7 @@ def __init__(self, name: str): self.uuid = self._get_uuid() # Unique identifier for the current final turn self.leftover_bytes = b"" self.message_context: list[MLLMClientMessageItem] = [] # Context for the current message + self.dynamic_prompt: str | None = None # Dynamic prompt set by prompt_template extension # States for TTFW calculation self.first_audio_time: float | None = ( @@ -102,6 +105,13 @@ async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None: function_call_output_json, _ = data.get_property_to_json(None) function_call_output = MLLMClientFunctionCallOutput.model_validate_json(function_call_output_json) await self.send_client_function_call_output(function_call_output) + elif data_name == DATA_MLLM_IN_SET_PROMPT: + prompt_json, _ = data.get_property_to_json(None) + prompt_data = json.loads(prompt_json) + new_prompt = prompt_data.get("prompt", "") + self.dynamic_prompt = new_prompt + ten_env.log_info(f"Dynamic prompt set: {new_prompt[:100]}...") + await self.on_prompt_update(new_prompt) async def on_stop(self, ten_env: AsyncTenEnv) -> None: @@ -235,6 +245,44 @@ async def send_audio(self, frame: AudioFrame, session_id: str | None) -> bool: """ raise NotImplementedError("This method should be implemented in subclasses.") + async def on_prompt_update(self, prompt: str) -> None: + """ + Called when a dynamic prompt is set by the prompt_template extension. + Subclasses can override this method to apply the new prompt. + Default implementation does nothing - subclass should update session if needed. + """ + pass + + def get_prompt(self, config_prompt: str, prompt_params: dict | None = None) -> str: + """ + Get the effective prompt to use, with optional template rendering. + + If prompt_params is provided and config_prompt contains {{placeholder}} patterns, + they will be replaced with the corresponding values from prompt_params. + + Args: + config_prompt: The prompt template from property.json (may contain {{placeholder}}) + prompt_params: Optional dict of parameters to render into the template + + Returns: + The rendered prompt string + """ + # If dynamic_prompt is set (by prompt_template extension), use it directly + if self.dynamic_prompt is not None: + return self.dynamic_prompt + + # If no params provided, return config_prompt as-is + if not prompt_params: + return config_prompt + + # Render template with {{placeholder}} syntax + def replace_placeholder(match): + key = match.group(1).strip() + return str(prompt_params.get(key, match.group(0))) # Keep original if not found + + pattern = r'\{\{(\w+)\}\}' + return re.sub(pattern, replace_placeholder, config_prompt) + @final async def send_server_session_ready(self, session: MLLMServerSessionReady) -> None: """ From d422b14b9f5ab1eac3dd290372f960a2a2722b97 Mon Sep 17 00:00:00 2001 From: Shawn Date: Mon, 12 Jan 2026 17:17:57 +0800 Subject: [PATCH 2/2] feat: support prompt params in llm --- interface/ten_ai_base/llm.py | 18 ++++++++++-- interface/ten_ai_base/llm2.py | 51 +++++++++++++++++++++++++++++++-- interface/ten_ai_base/struct.py | 3 +- 3 files changed, 67 insertions(+), 5 deletions(-) diff --git a/interface/ten_ai_base/llm.py b/interface/ten_ai_base/llm.py index cdbed71..6687078 100644 --- a/interface/ten_ai_base/llm.py +++ b/interface/ten_ai_base/llm.py @@ -5,6 +5,7 @@ # from abc import ABC, abstractmethod import asyncio +import re import traceback from ten_runtime import ( @@ -95,15 +96,28 @@ async def on_prompt_update(self, prompt: str) -> None: """ pass - def get_prompt(self, config_prompt: str) -> str: + def get_prompt( + self, config_prompt: str, prompt_params: dict | None = None + ) -> str: """ Get the effective prompt to use. Returns dynamic_prompt if set by prompt_template extension, otherwise returns the config_prompt from property.json. + + If prompt_params is provided and config_prompt contains {{placeholder}} patterns, + they will be replaced with the corresponding values from prompt_params. """ if self.dynamic_prompt is not None: return self.dynamic_prompt - return config_prompt + if not prompt_params: + return config_prompt + + def replace_placeholder(match): + key = match.group(1).strip() + return str(prompt_params.get(key, match.group(0))) + + pattern = r"\{\{(\w+)\}\}" + return re.sub(pattern, replace_placeholder, config_prompt) async def on_cmd(self, async_ten_env: AsyncTenEnv, cmd: Cmd) -> None: """ diff --git a/interface/ten_ai_base/llm2.py b/interface/ten_ai_base/llm2.py index 3eb5752..29adb12 100644 --- a/interface/ten_ai_base/llm2.py +++ b/interface/ten_ai_base/llm2.py @@ -6,12 +6,21 @@ from abc import ABC, abstractmethod import asyncio import json +import re import traceback from typing import AsyncGenerator, Dict, Optional -from .struct import LLMRequest, LLMRequestAbort, LLMRequestRetrievePrompt, LLMResponse, LLMResponseRetrievePrompt +from .llm import DATA_LLM_IN_SET_PROMPT +from .struct import ( + LLMRequest, + LLMRequestAbort, + LLMRequestRetrievePrompt, + LLMResponse, + LLMResponseRetrievePrompt, +) from ten_runtime import ( AsyncExtension, + Data, ) from ten_runtime.async_ten_env import AsyncTenEnv from ten_runtime.cmd import Cmd @@ -35,6 +44,7 @@ def __init__(self, name: str): self.ten_env: AsyncTenEnv = None self._inflight: Dict[str, "AsyncLLM2BaseExtension._TaskCtx"] = {} self._lock = asyncio.Lock() + self.dynamic_prompt: str | None = None # Dynamic prompt set by prompt_template extension async def on_init(self, async_ten_env: AsyncTenEnv) -> None: @@ -44,6 +54,17 @@ async def on_init(self, async_ten_env: AsyncTenEnv) -> None: async def on_start(self, async_ten_env: AsyncTenEnv) -> None: await super().on_start(async_ten_env) + async def on_data(self, async_ten_env: AsyncTenEnv, data: Data) -> None: + """Handle incoming data messages (dynamic prompt updates).""" + data_name = data.get_name() + async_ten_env.log_debug(f"[LLM2Base] on_data: {data_name}") + if data_name == DATA_LLM_IN_SET_PROMPT: + prompt_json, _ = data.get_property_to_json(None) + prompt_data = json.loads(prompt_json) + new_prompt = prompt_data.get("prompt", "") + self.dynamic_prompt = new_prompt + async_ten_env.log_info(f"[LLM2Base] Dynamic prompt set: {new_prompt[:100]}...") + async def on_stop(self, async_ten_env: AsyncTenEnv) -> None: await self._cancel_all() await super().on_stop(async_ten_env) @@ -66,6 +87,13 @@ async def on_cmd(self, async_ten_env: AsyncTenEnv, cmd: Cmd) -> None: if not rid: raise RuntimeError("LLMRequest.request_id is required") + # Apply prompt template rendering if prompt is provided in request + prompt_params = getattr(req, "prompt_params", None) + if req.prompt is not None: + req.prompt = self.get_prompt(req.prompt, prompt_params) + elif self.dynamic_prompt is not None: + req.prompt = self.dynamic_prompt + # Reject duplicates instead of replacing async with self._lock: existing = self._inflight.get(rid) @@ -140,6 +168,25 @@ def __init__(self, task: asyncio.Task, cmd: Cmd, request_id: str): self.cmd = cmd self.request_id = request_id + def get_prompt(self, config_prompt: str, prompt_params: dict | None = None) -> str: + """ + Get the effective prompt to use, with optional template rendering. + + If prompt_params is provided and config_prompt contains {{placeholder}} patterns, + they will be replaced with the corresponding values from prompt_params. + """ + if self.dynamic_prompt is not None: + return self.dynamic_prompt + if not prompt_params: + return config_prompt + + def replace_placeholder(match): + key = match.group(1).strip() + return str(prompt_params.get(key, match.group(0))) + + pattern = r"\{\{(\w+)\}\}" + return re.sub(pattern, replace_placeholder, config_prompt) + async def _start_locked(self, ten_env: AsyncTenEnv, cmd: Cmd, req: LLMRequest) -> None: """Call with self._lock held. Starts a task and registers it in _inflight.""" rid = req.request_id @@ -228,4 +275,4 @@ def on_call_chat_completion( """Called when a chat completion is requested by cmd call. Implement this method to process the chat completion.""" raise NotImplementedError( "on_call_chat_completion must be implemented in the subclass" - ) \ No newline at end of file + ) diff --git a/interface/ten_ai_base/struct.py b/interface/ten_ai_base/struct.py index 509e034..239ed32 100644 --- a/interface/ten_ai_base/struct.py +++ b/interface/ten_ai_base/struct.py @@ -139,6 +139,7 @@ class LLMRequest(BaseModel): tools: Optional[list[LLMToolMetadata]] = None parameters: Optional[dict[str, Any]] = None prompt: Optional[str] = None + prompt_params: Optional[dict[str, Any]] = None class LLMRequestAbort(BaseModel): """ @@ -334,4 +335,4 @@ class MLLMServerFunctionCall(BaseModel): call_id: str name: str arguments: str - metadata: dict[str, Any] = {} \ No newline at end of file + metadata: dict[str, Any] = {}