From 9eb3ef5ea30fcb223379a17941ebaeb2598df72c Mon Sep 17 00:00:00 2001 From: vrtornisiello Date: Tue, 10 Mar 2026 09:30:18 -0300 Subject: [PATCH 1/3] refactor: migrate agent creation to LangChain's create_agent function The new `create_agent` function from `langchain.agents` is the standard way to build agents in LangChain v1. It provides a simpler interface while offering greater customization potential through middleware. Key changes: - Replace custom `ReActAgent` with `create_agent` - Replace custom `trim_messages_before_agent` hook with built-in `SummarizationMiddleware` - Replace custom recursion limit handling with `ModelCallLimitMiddleware` - Update `_process_chunk` to handle the new node names ("model" instead of "agent", "ModelCallLimitMiddleware.before_model" for limit events) --- app/agent/__init__.py | 3 - app/agent/hooks.py | 29 ---- app/agent/react_agent.py | 312 ---------------------------------- app/agent/types.py | 3 - app/api/dependencies/agent.py | 7 +- app/api/streaming.py | 64 ++----- app/main.py | 22 ++- app/settings.py | 7 - 8 files changed, 31 insertions(+), 416 deletions(-) delete mode 100644 app/agent/hooks.py delete mode 100644 app/agent/react_agent.py delete mode 100644 app/agent/types.py diff --git a/app/agent/__init__.py b/app/agent/__init__.py index d429377..e69de29 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -1,3 +0,0 @@ -from .react_agent import ReActAgent - -__all__ = ["ReActAgent"] diff --git a/app/agent/hooks.py b/app/agent/hooks.py deleted file mode 100644 index 57abc2f..0000000 --- a/app/agent/hooks.py +++ /dev/null @@ -1,29 +0,0 @@ -from langchain.messages import RemoveMessage -from langchain_core.messages.base import BaseMessage -from langchain_core.messages.utils import count_tokens_approximately, trim_messages -from langgraph.graph.message import REMOVE_ALL_MESSAGES - -from app.agent.types import StateT -from app.settings import settings - - -def trim_messages_before_agent(state: StateT) -> dict[str, BaseMessage]: - messages = state["messages"] - - # For the first message, skip trimming. If it's too long, let it fail. - if len(messages) == 1: - return {"messages": []} - - # For subsequent turns, trim chat history to fit within token limits. - remaining_messages = trim_messages( - messages, - token_counter=count_tokens_approximately, # The accurate counter is too slow. - max_tokens=settings.MAX_TOKENS, - strategy="last", - start_on="human", - end_on="human", - include_system=True, - allow_partial=False, - ) - - return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), *remaining_messages]} diff --git a/app/agent/react_agent.py b/app/agent/react_agent.py deleted file mode 100644 index 340b759..0000000 --- a/app/agent/react_agent.py +++ /dev/null @@ -1,312 +0,0 @@ -from collections.abc import Callable -from typing import ( - Annotated, - AsyncIterator, - Generic, - Iterator, - Literal, - Sequence, - Type, - TypedDict, -) - -from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.messages import AIMessage, BaseMessage, SystemMessage -from langchain_core.runnables import RunnableConfig, RunnableLambda -from langchain_core.tools import BaseTool, BaseToolkit -from langgraph.checkpoint.postgres import PostgresSaver -from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver -from langgraph.graph.message import add_messages -from langgraph.graph.state import CompiledStateGraph, StateGraph -from langgraph.managed import IsLastStep, RemainingSteps -from langgraph.prebuilt import ToolNode -from loguru import logger - -from app.agent.types import StateT - - -class ReActState(TypedDict): - messages: Annotated[list[BaseMessage], add_messages] - """Message list""" - - is_last_step: IsLastStep - """Flag indicating if the last step has been reached""" - - remaining_steps: RemainingSteps - """Number of remaining steps before reaching the steps limit""" - - -class ReActAgent(Generic[StateT]): - """A LangGraph ReAct Agent.""" - - agent_node = "agent" - tools_node = "tools" - start_hook_node = "start_hook" - - def __init__( - self, - model: BaseChatModel, - tools: Sequence[BaseTool] | BaseToolkit, - state_schema: Type[StateT] = ReActState, - start_hook: Callable[[StateT], dict] | None = None, - system_prompt: SystemMessage | str | None = None, - checkpointer: PostgresSaver | AsyncPostgresSaver | bool | None = None, - ): - if isinstance(tools, BaseToolkit): - self.tools = tools.get_tools() - else: - self.tools = tools - - if isinstance(system_prompt, str): - self.system_message = SystemMessage(system_prompt) - else: - self.system_message = system_prompt - - self.model = model.bind_tools(self.tools) - - if self.system_message: - self.model_runnable = ( - lambda messages: [self.system_message] + messages - ) | self.model - else: - self.model_runnable = self.model - - self.checkpointer = checkpointer - - self.graph = self._compile(state_schema, start_hook) - - def _call_model( - self, state: StateT, config: RunnableConfig - ) -> dict[str, list[BaseMessage]]: - """Calls the LLM on a message list. - - Args: - state (StateT): The graph state. - config (RunnableConfig): A config to use when calling the LLM. - - Returns: - dict[str, list[BaseMessage]]: The updated message list. - """ - messages = state["messages"] - is_last_step = state["is_last_step"] - remaining_steps = state["remaining_steps"] - - response: AIMessage = self.model_runnable.invoke(messages, config) - - if not response.content and not response.tool_calls: - logger.warning("Empty model response, skipping message list update") - return {"messages": []} - - if ( - is_last_step - and response.tool_calls - or remaining_steps < 2 - and response.tool_calls - ): - return { - "messages": [ - AIMessage( - id=response.id, - content=( - "Desculpe, não consegui encontrar uma resposta para a sua pergunta. " - "Por favor, tente reformulá-la ou pergunte algo diferente." - ), - ) - ] - } - - return {"messages": [response]} - - async def _acall_model( - self, state: StateT, config: RunnableConfig - ) -> dict[str, list[BaseMessage]]: - """Asynchronously calls the LLM on a message list. - - Args: - state (StateT): The graph state. - config (RunnableConfig): A config to use when calling the LLM. - - Returns: - dict[str, list[BaseMessage]]: The updated message list. - """ - messages = state["messages"] - is_last_step = state["is_last_step"] - remaining_steps = state["remaining_steps"] - - response: AIMessage = await self.model_runnable.ainvoke(messages, config) - - if not response.content and not response.tool_calls: - logger.warning("Empty model response, skipping message list update") - return {"messages": []} - - if ( - is_last_step - and response.tool_calls - or remaining_steps < 2 - and response.tool_calls - ): - return { - "messages": [ - AIMessage( - id=response.id, - content=( - "Desculpe, não consegui encontrar uma resposta para a sua pergunta. " - "Por favor, tente reformulá-la ou pergunte algo diferente." - ), - ) - ] - } - - return {"messages": [response]} - - def _compile( - self, state_schema: Type[StateT], start_hook: Callable[[StateT], dict] | None - ) -> CompiledStateGraph: - """Compiles the state graph into a LangChain Runnable. - - Args: - state_schema (Type[StateT]): The state graph schema. - start_hook (Callable[[StateT], dict] | None): An optional node to add before the agent node. - Useful for managing long message histories (e.g., message trimming, summarization, etc.). - Must be a callable or a runnable that takes in current graph state and returns a state update. - - Returns: - CompiledStateGraph: The compiled state graph. - """ # noqa: E501 - graph = StateGraph(state_schema) - - graph.add_node( - self.agent_node, RunnableLambda(self._call_model, self._acall_model) - ) - graph.add_node(self.tools_node, ToolNode(self.tools)) - - if start_hook is not None: - graph.add_node(self.start_hook_node, start_hook) - graph.add_edge(self.start_hook_node, self.agent_node) - entrypoint = self.start_hook_node - else: - entrypoint = self.agent_node - - graph.set_entry_point(entrypoint) - graph.add_edge(self.tools_node, self.agent_node) - graph.add_conditional_edges(self.agent_node, _should_continue) - - # The checkpointer is ignored by default when the graph is used as a subgraph - # For more information, visit https://langchain-ai.github.io/langgraph/how-tos/subgraph-persistence - # If you want to persist the subgraph state between runs, you must use checkpointer=True - # For more information, visit https://github.com/langchain-ai/langgraph/issues/3020 - return graph.compile(self.checkpointer) - - def invoke(self, input: dict, config: RunnableConfig | None = None) -> StateT: - """Runs the compiled graph with an optional configuration. - - Args: - input (dict): The input data for the graph. - config (RunnableConfig | None, optional): The configuration. Defaults to `None`. - - Returns: - StateT: The last output of the graph run. - """ - return self.graph.invoke(input=input, config=config) - - async def ainvoke( - self, input: dict, config: RunnableConfig | None = None - ) -> StateT: - """Asynchronously runs the compiled graph with an optional configuration. - - Args: - input (dict): The input data for the graph. - config (RunnableConfig | None, optional): The configuration. Defaults to `None`. - - Returns: - StateT: The last output of the graph run. - """ - return await self.graph.ainvoke( - input=input, - config=config, - ) - - def stream( - self, - input: dict, - config: RunnableConfig | None = None, - stream_mode: list[str] | None = None, - ) -> Iterator[dict | tuple]: - """Stream graph steps. - - Args: - input (dict): The input data for the graph. - config (RunnableConfig | None, optional): Optional configuration for the agent execution. Defaults to `None`. - stream_mode (list[str] | None, optional): The mode to stream output. See the LangGraph streaming guide in - https://langchain-ai.github.io/langgraph/how-tos/streaming for more details. Defaults to `None`. - - Yields: - dict|tuple: The output for each step in the graph. Its type, shape and content depends on the `stream_mode` arg. - """ - for chunk in self.graph.stream( - input=input, - config=config, - stream_mode=stream_mode, - ): - yield chunk - - async def astream( - self, - input: dict, - config: RunnableConfig | None = None, - stream_mode: list[str] | None = None, - ) -> AsyncIterator[dict | tuple]: - """Asynchronously stream graph steps. - - Args: - input (dict): The input data for the graph. - config (RunnableConfig | None, optional): Optional configuration for the agent execution. Defaults to `None`. - stream_mode (list[str] | None, optional): The mode to stream output. See the LangGraph streaming guide in - https://langchain-ai.github.io/langgraph/how-tos/streaming for more details. Defaults to `None`. - - Yields: - dict|tuple: The output for each step in the graph. Its type, shape and content depends on the `stream_mode` arg. - """ - async for chunk in self.graph.astream( - input=input, - config=config, - stream_mode=stream_mode, - ): - yield chunk - - # Unfortunately, there is no clean way to delete an agent's memory - # except by deleting its checkpoints, as noted in this github discussion: - # https://github.com/langchain-ai/langgraph/discussions/912 - def clear_thread(self, thread_id: str): - """Deletes all checkpoints for a given thread. - - Args: - thread_id (str): The thread unique identifier. - """ - if self.checkpointer is not None: - self.checkpointer.delete_thread(thread_id) - - async def aclear_thread(self, thread_id: str): - """Asynchronously deletes all checkpoints for a given thread. - - Args: - thread_id (str): The thread unique identifier. - """ - if self.checkpointer is not None: - await self.checkpointer.adelete_thread(thread_id) - - -def _should_continue(state: StateT) -> Literal["tools", "__end__"]: - """Routes to the tools node if the last message has any tool calls. - Otherwise, routes to the message pruning node. - - Args: - state (StateT): The graph state. - - Returns: - str: The next node to route to. - """ - last_message = state["messages"][-1] - if hasattr(last_message, "tool_calls") and len(last_message.tool_calls) > 0: - return "tools" - return "__end__" diff --git a/app/agent/types.py b/app/agent/types.py deleted file mode 100644 index edb9a6f..0000000 --- a/app/agent/types.py +++ /dev/null @@ -1,3 +0,0 @@ -from typing import TypeVar - -StateT = TypeVar("StateT") diff --git a/app/api/dependencies/agent.py b/app/api/dependencies/agent.py index 283aec3..ef0fd53 100644 --- a/app/api/dependencies/agent.py +++ b/app/api/dependencies/agent.py @@ -1,12 +1,11 @@ from typing import Annotated from fastapi import Depends, Request +from langgraph.graph.state import CompiledStateGraph -from app.agent import ReActAgent - -def get_agent(request: Request) -> ReActAgent: +def get_agent(request: Request) -> CompiledStateGraph: return request.app.state.agent -Agent = Annotated[ReActAgent, Depends(get_agent)] +Agent = Annotated[CompiledStateGraph, Depends(get_agent)] diff --git a/app/api/streaming.py b/app/api/streaming.py index 53da1c9..9dad56c 100644 --- a/app/api/streaming.py +++ b/app/api/streaming.py @@ -2,10 +2,7 @@ import uuid from typing import Any, AsyncIterator, Literal -from google.api_core import exceptions as google_api_exceptions -from langchain.chat_models import init_chat_model from langchain_core.messages import AIMessage, ToolMessage -from langgraph.errors import GraphRecursionError from langgraph.graph.state import CompiledStateGraph from loguru import logger from pydantic import BaseModel, JsonValue @@ -13,7 +10,6 @@ from app.api.schemas import ConfigDict from app.db.database import AsyncDatabase from app.db.models import Message, MessageCreate, MessageRole, MessageStatus -from app.settings import settings class ToolCall(BaseModel): @@ -62,15 +58,9 @@ class ErrorMessage: "Se o problema persistir, avise-nos. Obrigado pela paciência!" ) - CONTEXT_OVERFLOW = ( - "Sua última mensagem ultrapassou o limite de tamanho para esta conversa. " - "Por favor, tente dividir sua solicitação em partes menores " - "ou inicie uma nova conversa." - ) - - GRAPH_RECURSION_LIMIT_REACHED = ( - "Desculpe, não consegui encontrar uma resposta para a sua pergunta. " - "Por favor, tente reformulá-la ou pergunte algo diferente." + MODEL_CALL_LIMIT_REACHED = ( + "Ops, essa pergunta gerou um raciocínio muito longo e não consegui chegar a uma conclusão. " + "Por favor, tente ser mais específico ou divida sua pergunta em partes menores." ) @@ -146,8 +136,8 @@ def _process_chunk(chunk: dict[str, Any]) -> StreamEvent | None: - "final_answer" for agent messages without tool calls - None for ignored chunks """ - if "agent" in chunk: - ai_messages: list[AIMessage] = chunk["agent"]["messages"] + if "model" in chunk: + ai_messages: list[AIMessage] = chunk["model"]["messages"] # If no messages are returned, the model returned an empty response # with no tool calls. This also counts as a final (but empty) answer. @@ -202,6 +192,11 @@ def _process_chunk(chunk: dict[str, Any]) -> StreamEvent | None: return StreamEvent( type="tool_output", data=EventData(tool_outputs=tool_outputs) ) + elif "ModelCallLimitMiddleware.before_model" in chunk: + event_data = EventData( + content=ErrorMessage.MODEL_CALL_LIMIT_REACHED, tool_calls=None + ) + return StreamEvent(type="final_answer", data=event_data) return None @@ -225,7 +220,6 @@ async def stream_response( """ events = [] artifacts = [] - agent_state = None assistant_message = "" status = MessageStatus.SUCCESS @@ -236,7 +230,6 @@ async def stream_response( stream_mode=["updates", "values"], ): if mode == "values": - agent_state = chunk continue event = _process_chunk(chunk) @@ -253,47 +246,10 @@ async def stream_response( events.append(event.model_dump()) yield event.to_sse() - - except GraphRecursionError: - logger.warning(f"Graph recursion limit reached for message {config['run_id']}") - - assistant_message = ErrorMessage.GRAPH_RECURSION_LIMIT_REACHED - - status = MessageStatus.SUCCESS - - yield StreamEvent( - type="final_answer", data=EventData(content=assistant_message) - ).to_sse() - - except google_api_exceptions.InvalidArgument: - logger.exception( - "Agent execution failed with Google API InvalidArgument error:" - ) - - assistant_message = ErrorMessage.UNEXPECTED - - status = MessageStatus.ERROR - - if agent_state is not None: - model = init_chat_model(settings.MODEL_URI) - total_tokens = model.get_num_tokens_from_messages(agent_state["messages"]) - - if total_tokens >= model.profile.get( - "max_input_tokens", settings.MAX_TOKENS - ): - assistant_message = ErrorMessage.CONTEXT_OVERFLOW - - yield StreamEvent( - type="error", data=EventData(error_details={"message": assistant_message}) - ).to_sse() - except Exception: logger.exception(f"Unexpected error responding message {config['run_id']}:") - assistant_message = ErrorMessage.UNEXPECTED - status = MessageStatus.ERROR - yield StreamEvent( type="error", data=EventData(error_details={"message": assistant_message}) ).to_sse() diff --git a/app/main.py b/app/main.py index 3f2d5ea..524bc1c 100644 --- a/app/main.py +++ b/app/main.py @@ -2,14 +2,17 @@ from fastapi import FastAPI from fastapi.responses import RedirectResponse +from langchain.agents import create_agent +from langchain.agents.middleware import ( + ModelCallLimitMiddleware, + SummarizationMiddleware, +) from langchain.chat_models import init_chat_model from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from loguru import logger from psycopg.rows import dict_row from psycopg_pool import AsyncConnectionPool -from app.agent import ReActAgent -from app.agent.hooks import trim_messages_before_agent from app.agent.prompts import SYSTEM_PROMPT from app.agent.tools import BDToolkit from app.api.main import api_router @@ -52,6 +55,17 @@ async def lifespan(app: FastAPI): # pragma: no cover credentials=settings.GOOGLE_CREDENTIALS, ) + summ_middleware = SummarizationMiddleware( + model=model, + trigger=("fraction", 0.5), + keep=("fraction", 0.25), + ) + + limit_middleware = ModelCallLimitMiddleware( + run_limit=20, + exit_behavior="end", + ) + async with AsyncConnectionPool( conninfo=settings.DB_URL, kwargs=conn_kwargs, @@ -62,11 +76,11 @@ async def lifespan(app: FastAPI): # pragma: no cover ) as pool: checkpointer = AsyncPostgresSaver(pool) - agent = ReActAgent( + agent = create_agent( model=model, tools=BDToolkit.get_tools(), - start_hook=trim_messages_before_agent, system_prompt=SYSTEM_PROMPT, + middleware=[summ_middleware, limit_middleware], checkpointer=checkpointer, ) diff --git a/app/settings.py b/app/settings.py index 3ff83d8..e64a2a7 100644 --- a/app/settings.py +++ b/app/settings.py @@ -111,13 +111,6 @@ def GOOGLE_CREDENTIALS(self) -> Credentials: # pragma: no cover "lower ones make them more deterministic." ) ) - MAX_TOKENS: int = Field( - description=( - "Maximum number of tokens for a single input. " - "Must be defined according to the provided model uri " - "and be less than the model's maximum input tokens." - ), - ) # ============================================================ # == LangSmith settings == From 3b5d7abfe40b550b85386329ee548c2d43f58b09 Mon Sep 17 00:00:00 2001 From: vrtornisiello Date: Tue, 10 Mar 2026 09:46:51 -0300 Subject: [PATCH 2/3] chore: update tests for streaming module --- tests/app/api/test_streaming.py | 116 ++++---------------------------- 1 file changed, 14 insertions(+), 102 deletions(-) diff --git a/tests/app/api/test_streaming.py b/tests/app/api/test_streaming.py index 29667ed..63b30e1 100644 --- a/tests/app/api/test_streaming.py +++ b/tests/app/api/test_streaming.py @@ -6,8 +6,6 @@ import pytest from google.api_core import exceptions as google_api_exceptions from langchain_core.messages import AIMessage, ToolMessage -from langgraph.errors import GraphRecursionError -from pytest_mock import MockerFixture from app.api.schemas import ConfigDict from app.api.streaming import ( @@ -119,7 +117,7 @@ class TestProcessChunk: def test_agent_chunk_with_tool_calls(self): """Test agent chunk with tool calls returns tool_call event.""" chunk = { - "agent": { + "model": { "messages": [ AIMessage( content="Let me search for that.", @@ -154,7 +152,7 @@ def test_agent_chunk_with_tool_calls(self): def test_agent_chunk_with_multiple_tool_calls(self): """Test agent chunk with multiple parallel tool calls.""" chunk = { - "agent": { + "model": { "messages": [ AIMessage( content="I'll search both.", @@ -181,7 +179,7 @@ def test_agent_chunk_with_multiple_tool_calls(self): def test_agent_chunk_final_answer(self): """Test agent chunk without tool calls returns final_answer event.""" - chunk = {"agent": {"messages": [AIMessage(content="Here is your answer.")]}} + chunk = {"model": {"messages": [AIMessage(content="Here is your answer.")]}} event = _process_chunk(chunk) @@ -195,7 +193,7 @@ def test_agent_chunk_final_answer(self): def test_agent_chunk_empty_messages(self): """Test agent chunk with empty messages list returns empty final_answer.""" - chunk = {"agent": {"messages": []}} + chunk = {"model": {"messages": []}} event = _process_chunk(chunk) @@ -383,7 +381,7 @@ async def mock_astream(*args, **kwargs): yield ( "updates", { - "agent": { + "model": { "messages": [ AIMessage( content="Let me search.", @@ -433,7 +431,7 @@ async def mock_astream(*args, **kwargs): yield "values", {"messages": ["msg1", "msg2"]} yield ( "updates", - {"agent": {"messages": [AIMessage(content="Here is your answer.")]}}, + {"model": {"messages": [AIMessage(content="Here is your answer.")]}}, ) yield "values", {"messages": ["msg1", "msg2", "msg3"]} @@ -499,7 +497,7 @@ async def mock_astream(*args, **kwargs): assert call_args.status == MessageStatus.ERROR assert call_args.content == ErrorMessage.UNEXPECTED - async def test_stream_response_graph_recursion_error( + async def test_stream_response_model_call_limit_reached( self, mock_database, mock_user_message, @@ -507,12 +505,14 @@ async def test_stream_response_graph_recursion_error( mock_thread_id, mock_model_uri, ): - """Test GraphRecursionError sets graceful message without error status.""" + """Test ModelCallLimitMiddleware sets graceful message without error status.""" mock_agent = MagicMock() async def mock_astream(*args, **kwargs): - raise GraphRecursionError("Recursion limit reached") - yield # Makes this an async generator + yield ( + "updates", + {"ModelCallLimitMiddleware.before_model": {"messages": []}}, + ) mock_agent.astream = mock_astream @@ -529,12 +529,12 @@ async def mock_astream(*args, **kwargs): assert len(events) == 2 assert '"type":"final_answer"' in events[0] - assert ErrorMessage.GRAPH_RECURSION_LIMIT_REACHED in events[0] + assert ErrorMessage.MODEL_CALL_LIMIT_REACHED in events[0] assert '"type":"complete"' in events[1] call_args = mock_database.create_message.call_args[0][0] assert call_args.status == MessageStatus.SUCCESS - assert call_args.content == ErrorMessage.GRAPH_RECURSION_LIMIT_REACHED + assert call_args.content == ErrorMessage.MODEL_CALL_LIMIT_REACHED async def test_stream_response_google_api_error( self, @@ -571,91 +571,3 @@ async def mock_astream(*args, **kwargs): call_args = mock_database.create_message.call_args[0][0] assert call_args.status == MessageStatus.ERROR assert call_args.content == ErrorMessage.UNEXPECTED - - async def test_stream_response_google_api_error_with_agent_state_below_limit( - self, - mocker: MockerFixture, - mock_database, - mock_user_message, - mock_config, - mock_thread_id, - mock_model_uri, - ): - """Test Google API error with agent_state set but tokens below limit.""" - mock_agent = MagicMock() - - async def mock_astream(*args, **kwargs): - yield "values", {"messages": ["msg1"]} # Sets agent_state - raise google_api_exceptions.InvalidArgument("Some other error") - - mock_agent.astream = mock_astream - - mock_model = MagicMock() - mock_model.get_num_tokens_from_messages.return_value = 999 # Below limit - mock_model.profile.get.return_value = 1_048_576 # Gemini context window - mocker.patch("app.api.streaming.init_chat_model", return_value=mock_model) - - events = await self._collect_events( - stream_response( - database=mock_database, - agent=mock_agent, - user_message=mock_user_message, - config=mock_config, - thread_id=mock_thread_id, - model_uri=mock_model_uri, - ) - ) - - assert len(events) == 2 - assert '"type":"error"' in events[0] - assert ErrorMessage.UNEXPECTED in events[0] # Not CONTEXT_OVERFLOW - assert '"type":"complete"' in events[1] - - call_args = mock_database.create_message.call_args[0][0] - assert call_args.status == MessageStatus.ERROR - assert call_args.content == ErrorMessage.UNEXPECTED - - async def test_stream_response_google_api_error_with_agent_state_context_overflow( - self, - mocker: MockerFixture, - mock_database, - mock_user_message, - mock_config, - mock_thread_id, - mock_model_uri, - ): - """Test Google API error with context window exceeded.""" - mock_agent = MagicMock() - - async def mock_astream(*args, **kwargs): - yield "values", {"messages": ["msg1"]} # Sets agent_state - raise google_api_exceptions.InvalidArgument("Token limit exceeded") - - mock_agent.astream = mock_astream - - mock_model = MagicMock() - mock_model.get_num_tokens_from_messages.return_value = ( - 9_999_999 # Exceeds limit - ) - mock_model.profile.get.return_value = 1_048_576 # Gemini context window - mocker.patch("app.api.streaming.init_chat_model", return_value=mock_model) - - events = await self._collect_events( - stream_response( - database=mock_database, - agent=mock_agent, - user_message=mock_user_message, - config=mock_config, - thread_id=mock_thread_id, - model_uri=mock_model_uri, - ) - ) - - assert len(events) == 2 - assert '"type":"error"' in events[0] - assert ErrorMessage.CONTEXT_OVERFLOW in events[0] - assert '"type":"complete"' in events[1] - - call_args = mock_database.create_message.call_args[0][0] - assert call_args.status == MessageStatus.ERROR - assert call_args.content == ErrorMessage.CONTEXT_OVERFLOW From 5f625632de2bdb3ee97db33de671f8eb1a64c59f Mon Sep 17 00:00:00 2001 From: vrtornisiello Date: Tue, 10 Mar 2026 09:52:26 -0300 Subject: [PATCH 3/3] docs: update .env example file --- .env.example | 1 - 1 file changed, 1 deletion(-) diff --git a/.env.example b/.env.example index eef0a6c..cb28d4b 100644 --- a/.env.example +++ b/.env.example @@ -39,7 +39,6 @@ GOOGLE_SERVICE_ACCOUNT=/app/credentials/chatbot-sa.json # ============================================================ MODEL_URI=google_genai:gemini-2.5-flash MODEL_TEMPERATURE=0.2 -MAX_TOKENS=524288 # ============================================================ # == LangSmith settings ==