@@ -19,6 +19,7 @@
from langchain_core.messages import AIMessage, BaseMessageChunk, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.pregel.protocol import PregelProtocol
from langgraph.types import StreamMode
from langgraph.typing import ContextT

from livekit.agents import llm, utils
@@ -31,6 +32,8 @@
NotGivenOr,
)

_SUPPORTED_MODES: set[StreamMode] = {"messages", "custom"}


class LLMAdapter(llm.LLM, Generic[ContextT]):
def __init__(
@@ -40,12 +43,20 @@ def __init__(
config: RunnableConfig | None = None,
context: ContextT | None = None,
subgraphs: bool = False,
stream_mode: StreamMode | list[StreamMode] = "messages",
) -> None:
super().__init__()
modes = {stream_mode} if isinstance(stream_mode, str) else set(stream_mode)
unsupported = modes - _SUPPORTED_MODES
if unsupported:
raise ValueError(
f"Unsupported stream mode(s): {unsupported}. Only {_SUPPORTED_MODES} are supported."
)
self._graph = graph
self._config = config
self._context = context
self._subgraphs = subgraphs
self._stream_mode = stream_mode

@property
def model(self) -> str:
@@ -75,6 +86,7 @@ def chat(
config=self._config,
context=self._context,
subgraphs=self._subgraphs,
stream_mode=self._stream_mode,
)


@@ -90,6 +102,7 @@ def __init__(
config: RunnableConfig | None = None,
context: ContextT | None = None,
subgraphs: bool = False,
stream_mode: StreamMode | list[StreamMode] = "messages",
):
super().__init__(
llm,
@@ -101,9 +114,11 @@ def __init__(
self._config = config
self._context = context
self._subgraphs = subgraphs
self._stream_mode = stream_mode

async def _run(self) -> None:
state = self._chat_ctx_to_state()
is_multi_mode = isinstance(self._stream_mode, list)

# Some LangGraph versions don't accept the `subgraphs` or `context` kwargs yet.
# Try with them first; fall back gracefully if unsupported.
@@ -112,24 +127,51 @@ async def _run(self) -> None:
state,
self._config,
context=self._context,
stream_mode="messages",
stream_mode=self._stream_mode,
subgraphs=self._subgraphs,
)
except TypeError:
aiter = self._graph.astream(
state,
self._config,
stream_mode="messages",
stream_mode=self._stream_mode,
)

async for item in aiter:
# Multi-mode: item is (mode, data) tuple wrapper
if is_multi_mode and isinstance(item, tuple) and len(item) == 2:
mode, data = item
if isinstance(mode, str):
if mode == "custom":
# data = payload (str, dict, object)
chat_chunk = _to_chat_chunk(data)
if chat_chunk:
self._event_ch.send_nowait(chat_chunk)
continue
elif mode == "messages":
# data = (token, metadata)
token_like = _extract_message_chunk(data)
if token_like is None:
continue
chat_chunk = _to_chat_chunk(token_like)
if chat_chunk:
self._event_ch.send_nowait(chat_chunk)
continue

# Single-mode: item is data directly (no tuple wrapper)
if self._stream_mode == "custom":
# item = payload (str, dict, object)
chat_chunk = _to_chat_chunk(item)
if chat_chunk:
self._event_ch.send_nowait(chat_chunk)
elif self._stream_mode == "messages":
# item = (token, metadata)
token_like = _extract_message_chunk(item)
if token_like is None:
continue
chat_chunk = _to_chat_chunk(token_like)
if chat_chunk:
self._event_ch.send_nowait(chat_chunk)

def _chat_ctx_to_state(self) -> dict[str, Any]:
"""Convert chat context to langgraph input"""
@@ -201,6 +243,14 @@ def _to_chat_chunk(msg: str | Any) -> llm.ChatChunk | None:
content = msg.text()
if getattr(msg, "id", None):
message_id = msg.id # type: ignore
    elif isinstance(msg, dict):
        # Custom payloads may arrive as plain dicts, e.g. {"content": "..."}.
        raw = msg.get("content")
        if isinstance(raw, str):
            content = raw
    elif hasattr(msg, "content"):
        # Duck-type arbitrary objects that expose a string `content` attribute.
        raw = msg.content
        if isinstance(raw, str):
            content = raw

if not content:
return None
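
For orientation, a hedged usage sketch of the new parameter (assuming `graph` is a compiled LangGraph graph; none of this is taken from the PR itself):

```python
# Default: token streaming only, matching the previous behavior.
adapter = LLMAdapter(graph)

# Opt in to interleaved token and custom-payload streaming:
adapter = LLMAdapter(graph, stream_mode=["messages", "custom"])

# Modes outside _SUPPORTED_MODES fail fast at construction time:
LLMAdapter(graph, stream_mode="updates")  # raises ValueError
```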