Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions interface/ten_ai_base/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#
from abc import ABC, abstractmethod
import asyncio
import re
import traceback

from ten_runtime import (
Expand All @@ -26,6 +27,8 @@
from .helper import AsyncQueue
import json

DATA_LLM_IN_SET_PROMPT = "llm_client_set_prompt"


class AsyncLLMBaseExtension(AsyncExtension, ABC):
"""
Expand All @@ -50,9 +53,12 @@ def __init__(self, name: str):
self.hit_default_cmd = False
self.loop_task = None
self.loop = None
self.dynamic_prompt: str | None = None # Dynamic prompt set by prompt_template extension
self.ten_env: AsyncTenEnv = None # type: ignore

async def on_init(self, async_ten_env: AsyncTenEnv) -> None:
    """Run the parent class initialization, then keep the environment handle.

    Args:
        async_ten_env: Environment object handed to this lifecycle callback.
    """
    await super().on_init(async_ten_env)
    # Keep a reference to the environment on the instance.
    self.ten_env = async_ten_env

async def on_start(self, async_ten_env: AsyncTenEnv) -> None:
await super().on_start(async_ten_env)
Expand All @@ -70,6 +76,49 @@ async def on_stop(self, async_ten_env: AsyncTenEnv) -> None:
async def on_deinit(self, async_ten_env: AsyncTenEnv) -> None:
    """Delegate de-initialization to the parent class.

    Args:
        async_ten_env: Environment object handed to this lifecycle callback.
    """
    await super().on_deinit(async_ten_env)

async def on_data(self, async_ten_env: AsyncTenEnv, data: Data) -> None:
    """Handle incoming data messages.

    Only messages named ``DATA_LLM_IN_SET_PROMPT`` are acted upon. Their
    payload is read as JSON and must be an object with a string ``"prompt"``
    entry. A valid prompt is stored in ``self.dynamic_prompt`` and then
    passed to ``self.on_prompt_update``.

    Malformed payloads (invalid JSON, non-object JSON, missing or non-string
    ``"prompt"``) are logged and ignored, so a bad message can neither crash
    the handler nor replace the current prompt with an empty value.

    Args:
        async_ten_env: Environment object, used here for logging.
        data: The incoming data message.
    """
    data_name = data.get_name()
    async_ten_env.log_debug(f"on_data name: {data_name}")
    if data_name != DATA_LLM_IN_SET_PROMPT:
        return

    prompt_json, _ = data.get_property_to_json(None)
    try:
        prompt_data = json.loads(prompt_json)
    except (TypeError, ValueError) as exc:
        # json.JSONDecodeError is a ValueError; TypeError covers a non-str payload.
        async_ten_env.log_warn(
            f"Ignoring {data_name}: payload is not valid JSON ({exc})"
        )
        return

    new_prompt = (
        prompt_data.get("prompt") if isinstance(prompt_data, dict) else None
    )
    if not isinstance(new_prompt, str):
        # get_prompt() treats any non-None dynamic_prompt as authoritative, so
        # storing a default here would silently override the configured prompt.
        async_ten_env.log_warn(
            f"Ignoring {data_name}: payload has no string 'prompt' field"
        )
        return

    self.dynamic_prompt = new_prompt
    async_ten_env.log_info(f"Dynamic prompt set: {new_prompt[:100]}...")
    await self.on_prompt_update(new_prompt)

async def on_prompt_update(self, prompt: str) -> None:
    """Hook invoked after a new dynamic prompt has been stored.

    The base implementation is intentionally a no-op. A subclass may override
    it to apply the new prompt, for example by updating its conversation
    context.

    Args:
        prompt: The prompt text that was just stored as the dynamic prompt.
    """
    return None

def get_prompt(
self, config_prompt: str, prompt_params: dict | None = None
) -> str:
"""
Get the effective prompt to use.
Returns dynamic_prompt if set by prompt_template extension,
otherwise returns the config_prompt from property.json.

If prompt_params is provided and config_prompt contains {{placeholder}} patterns,
they will be replaced with the corresponding values from prompt_params.
"""
if self.dynamic_prompt is not None:
return self.dynamic_prompt
if not prompt_params:
return config_prompt

def replace_placeholder(match):
key = match.group(1).strip()
return str(prompt_params.get(key, match.group(0)))

pattern = r"\{\{(\w+)\}\}"
return re.sub(pattern, replace_placeholder, config_prompt)

async def on_cmd(self, async_ten_env: AsyncTenEnv, cmd: Cmd) -> None:
"""
handle default commands
Expand Down
51 changes: 49 additions & 2 deletions interface/ten_ai_base/llm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,21 @@
from abc import ABC, abstractmethod
import asyncio
import json
import re
import traceback
from typing import AsyncGenerator, Dict, Optional

from .struct import LLMRequest, LLMRequestAbort, LLMRequestRetrievePrompt, LLMResponse, LLMResponseRetrievePrompt
from .llm import DATA_LLM_IN_SET_PROMPT
from .struct import (
LLMRequest,
LLMRequestAbort,
LLMRequestRetrievePrompt,
LLMResponse,
LLMResponseRetrievePrompt,
)
from ten_runtime import (
AsyncExtension,
Data,
)
from ten_runtime.async_ten_env import AsyncTenEnv
from ten_runtime.cmd import Cmd
Expand All @@ -35,6 +44,7 @@ def __init__(self, name: str):
self.ten_env: AsyncTenEnv = None
self._inflight: Dict[str, "AsyncLLM2BaseExtension._TaskCtx"] = {}
self._lock = asyncio.Lock()
self.dynamic_prompt: str | None = None # Dynamic prompt set by prompt_template extension


async def on_init(self, async_ten_env: AsyncTenEnv) -> None:
Expand All @@ -44,6 +54,17 @@ async def on_init(self, async_ten_env: AsyncTenEnv) -> None:
async def on_start(self, async_ten_env: AsyncTenEnv) -> None:
    """Delegate start-up to the parent class.

    Args:
        async_ten_env: Environment object handed to this lifecycle callback.
    """
    await super().on_start(async_ten_env)

async def on_data(self, async_ten_env: AsyncTenEnv, data: Data) -> None:
    """Handle incoming data messages (dynamic prompt updates).

    Only messages named ``DATA_LLM_IN_SET_PROMPT`` are acted upon. Their
    payload is read as JSON and must be an object with a string ``"prompt"``
    entry, which is then stored in ``self.dynamic_prompt``.

    Malformed payloads (invalid JSON, non-object JSON, missing or non-string
    ``"prompt"``) are logged and ignored, so a bad message can neither crash
    the handler nor replace the current prompt with an empty value.

    Args:
        async_ten_env: Environment object, used here for logging.
        data: The incoming data message.
    """
    data_name = data.get_name()
    async_ten_env.log_debug(f"[LLM2Base] on_data: {data_name}")
    if data_name != DATA_LLM_IN_SET_PROMPT:
        return

    prompt_json, _ = data.get_property_to_json(None)
    try:
        prompt_data = json.loads(prompt_json)
    except (TypeError, ValueError) as exc:
        # json.JSONDecodeError is a ValueError; TypeError covers a non-str payload.
        async_ten_env.log_warn(
            f"[LLM2Base] Ignoring {data_name}: payload is not valid JSON ({exc})"
        )
        return

    new_prompt = (
        prompt_data.get("prompt") if isinstance(prompt_data, dict) else None
    )
    if not isinstance(new_prompt, str):
        # get_prompt() treats any non-None dynamic_prompt as authoritative, so
        # storing a default here would silently override request prompts.
        async_ten_env.log_warn(
            f"[LLM2Base] Ignoring {data_name}: payload has no string 'prompt' field"
        )
        return

    self.dynamic_prompt = new_prompt
    async_ten_env.log_info(f"[LLM2Base] Dynamic prompt set: {new_prompt[:100]}...")

async def on_stop(self, async_ten_env: AsyncTenEnv) -> None:
await self._cancel_all()
await super().on_stop(async_ten_env)
Expand All @@ -66,6 +87,13 @@ async def on_cmd(self, async_ten_env: AsyncTenEnv, cmd: Cmd) -> None:
if not rid:
raise RuntimeError("LLMRequest.request_id is required")

# Apply prompt template rendering if prompt is provided in request
prompt_params = getattr(req, "prompt_params", None)
if req.prompt is not None:
req.prompt = self.get_prompt(req.prompt, prompt_params)
elif self.dynamic_prompt is not None:
req.prompt = self.dynamic_prompt

# Reject duplicates instead of replacing
async with self._lock:
existing = self._inflight.get(rid)
Expand Down Expand Up @@ -140,6 +168,25 @@ def __init__(self, task: asyncio.Task, cmd: Cmd, request_id: str):
self.cmd = cmd
self.request_id = request_id

def get_prompt(self, config_prompt: str, prompt_params: dict | None = None) -> str:
"""
Get the effective prompt to use, with optional template rendering.

If prompt_params is provided and config_prompt contains {{placeholder}} patterns,
they will be replaced with the corresponding values from prompt_params.
"""
if self.dynamic_prompt is not None:
return self.dynamic_prompt
if not prompt_params:
return config_prompt

def replace_placeholder(match):
key = match.group(1).strip()
return str(prompt_params.get(key, match.group(0)))

pattern = r"\{\{(\w+)\}\}"
return re.sub(pattern, replace_placeholder, config_prompt)

async def _start_locked(self, ten_env: AsyncTenEnv, cmd: Cmd, req: LLMRequest) -> None:
"""Call with self._lock held. Starts a task and registers it in _inflight."""
rid = req.request_id
Expand Down Expand Up @@ -228,4 +275,4 @@ def on_call_chat_completion(
"""Called when a chat completion is requested by cmd call. Implement this method to process the chat completion."""
raise NotImplementedError(
"on_call_chat_completion must be implemented in the subclass"
)
)
48 changes: 48 additions & 0 deletions interface/ten_ai_base/mllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import traceback
from typing import final
import uuid
import re

from .struct import MLLMClientCreateResponse, MLLMClientFunctionCallOutput, MLLMClientMessageItem, MLLMClientRegisterTool, MLLMClientSendMessageItem, MLLMClientSetMessageContext, MLLMServerFunctionCall, MLLMServerInputTranscript, MLLMServerOutputTranscript, MLLMServerSessionReady, MLLMServerInterrupt

Expand Down Expand Up @@ -44,6 +45,7 @@
DATA_MLLM_IN_CREATE_RESPONSE = "mllm_client_request_create_response"
DATA_MLLM_IN_REGISTER_TOOL = "mllm_client_request_register_tool"
DATA_MLLM_IN_FUNCTION_CALL_OUTPUT = "mllm_client_function_call_output"
DATA_MLLM_IN_SET_PROMPT = "mllm_client_set_prompt"


class AsyncMLLMBaseExtension(AsyncExtension):
Expand All @@ -60,6 +62,7 @@ def __init__(self, name: str):
self.uuid = self._get_uuid() # Unique identifier for the current final turn
self.leftover_bytes = b""
self.message_context: list[MLLMClientMessageItem] = [] # Context for the current message
self.dynamic_prompt: str | None = None # Dynamic prompt set by prompt_template extension

# States for TTFW calculation
self.first_audio_time: float | None = (
Expand Down Expand Up @@ -102,6 +105,13 @@ async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None:
function_call_output_json, _ = data.get_property_to_json(None)
function_call_output = MLLMClientFunctionCallOutput.model_validate_json(function_call_output_json)
await self.send_client_function_call_output(function_call_output)
elif data_name == DATA_MLLM_IN_SET_PROMPT:
prompt_json, _ = data.get_property_to_json(None)
prompt_data = json.loads(prompt_json)
new_prompt = prompt_data.get("prompt", "")
self.dynamic_prompt = new_prompt
ten_env.log_info(f"Dynamic prompt set: {new_prompt[:100]}...")
await self.on_prompt_update(new_prompt)


async def on_stop(self, ten_env: AsyncTenEnv) -> None:
Expand Down Expand Up @@ -235,6 +245,44 @@ async def send_audio(self, frame: AudioFrame, session_id: str | None) -> bool:
"""
raise NotImplementedError("This method should be implemented in subclasses.")

async def on_prompt_update(self, prompt: str) -> None:
    """Hook invoked after a new dynamic prompt has been stored.

    The base implementation is intentionally a no-op. A subclass may override
    it to apply the new prompt, for example by updating its session.

    Args:
        prompt: The prompt text that was just stored as the dynamic prompt.
    """
    return None

def get_prompt(self, config_prompt: str, prompt_params: dict | None = None) -> str:
"""
Get the effective prompt to use, with optional template rendering.

If prompt_params is provided and config_prompt contains {{placeholder}} patterns,
they will be replaced with the corresponding values from prompt_params.

Args:
config_prompt: The prompt template from property.json (may contain {{placeholder}})
prompt_params: Optional dict of parameters to render into the template

Returns:
The rendered prompt string
"""
# If dynamic_prompt is set (by prompt_template extension), use it directly
if self.dynamic_prompt is not None:
return self.dynamic_prompt

# If no params provided, return config_prompt as-is
if not prompt_params:
return config_prompt

# Render template with {{placeholder}} syntax
def replace_placeholder(match):
key = match.group(1).strip()
return str(prompt_params.get(key, match.group(0))) # Keep original if not found

pattern = r'\{\{(\w+)\}\}'
return re.sub(pattern, replace_placeholder, config_prompt)

@final
async def send_server_session_ready(self, session: MLLMServerSessionReady) -> None:
"""
Expand Down
3 changes: 2 additions & 1 deletion interface/ten_ai_base/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ class LLMRequest(BaseModel):
tools: Optional[list[LLMToolMetadata]] = None
parameters: Optional[dict[str, Any]] = None
prompt: Optional[str] = None
prompt_params: Optional[dict[str, Any]] = None

class LLMRequestAbort(BaseModel):
"""
Expand Down Expand Up @@ -334,4 +335,4 @@ class MLLMServerFunctionCall(BaseModel):
call_id: str
name: str
arguments: str
metadata: dict[str, Any] = {}
metadata: dict[str, Any] = {}