Skip to content

Commit 4fd0fe4

Browse files
authored
Merge pull request #65 from pattern-tech/fix/streaming
fix: streaming part updated
2 parents d597b86 + 9c1bcbf commit 4fd0fe4

1 file changed

Lines changed: 121 additions & 62 deletions

File tree

Lines changed: 121 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import asyncio
3+
from typing import Dict, Any, Optional, AsyncGenerator
34

45
from langchain.agents import AgentExecutor
56
from langchain.callbacks.base import BaseCallbackHandler
@@ -13,26 +14,40 @@
1314
class StreamingCallbackHandler(BaseCallbackHandler):
    """
    Callback handler that funnels LLM tokens and agent events into an asyncio queue.

    Events are serialized with a newline-delimited JSON (NDJSON) protocol:
    every queue entry is one complete JSON object terminated by "\n",
    which lets a consumer split on newlines and parse each line independently.
    """

    def __init__(self):
        # Unbounded queue; producers use put_nowait so callbacks never block.
        self.queue = asyncio.Queue()

    def _enqueue(self, event) -> None:
        # Serialize one event as a single NDJSON line and hand it to the queue.
        self.queue.put_nowait(json.dumps(event) + "\n")

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """
        Handle new tokens from the LLM.

        Args:
            token (str): The new token from the LLM.
            **kwargs: Additional keyword arguments.
        """
        self._enqueue({"type": "token", "data": token})

    def on_agent_action(self, action, **kwargs) -> None:
        """
        Handle agent actions.

        Args:
            action: The action being performed by the agent.
            **kwargs: Additional keyword arguments.
        """
        # getattr with defaults keeps this safe for action objects
        # that lack the expected attributes.
        self._enqueue({
            "type": "tool_start",
            "tool": getattr(action, "tool", None),
            "tool_input": getattr(action, "tool_input", {}),
        })
3752

3853

@@ -43,9 +58,23 @@ class RouterAgentService:
4358
"""
4459

4560
def __init__(self, sub_agents, memory=None, streaming: bool = True):
61+
"""
62+
Initialize the RouterAgentService.
63+
64+
Args:
65+
sub_agents: The sub-agents to use for routing.
66+
memory: The memory to use for storing conversation history.
67+
streaming (bool): Whether to enable streaming responses.
68+
"""
4669
self.sub_agents = sub_agents
4770
self.memory = memory
4871
self.streaming = streaming
72+
self.streaming_handler = None
73+
74+
# Default timeout values that can be adjusted if needed
75+
self.token_timeout = 0.01
76+
self.buffer_timeout = 0.005
77+
self.poll_interval = 0.01
4978

5079
# Set up the streaming callback if streaming is enabled.
5180
if streaming:
@@ -87,8 +116,31 @@ def __init__(self, sub_agents, memory=None, streaming: bool = True):
87116
history_messages_key="chat_history",
88117
)
89118

90-
async def stream(self, message: str):
119+
async def _process_complete_json(self, buffer: str) -> tuple[list[str], str]:
91120
"""
121+
Process a buffer to extract complete JSON objects.
122+
123+
Args:
124+
buffer (str): The buffer containing JSON data.
125+
126+
Returns:
127+
tuple: A tuple containing a list of complete JSON strings and any remaining buffer.
128+
"""
129+
results = []
130+
remaining = buffer
131+
132+
# Process all complete objects in the buffer
133+
while "\n" in remaining:
134+
json_str, remaining = remaining.split("\n", 1)
135+
if json_str: # Only include non-empty strings
136+
results.append(json_str + "\n")
137+
138+
return results, remaining
139+
140+
async def stream(self, message: str) -> AsyncGenerator[str, None]:
141+
"""
142+
Stream the agent's response to the input message.
143+
92144
Args:
93145
message (str): The input message to be processed by the agent.
94146
@@ -99,16 +151,17 @@ async def stream(self, message: str):
99151
asyncio.TimeoutError: If waiting for a token from the queue times out.
100152
101153
Notes:
102-
- If memory is enabled, the agent's response is invoked synchronously using `run_in_executor`.
103-
- If memory is not enabled, the agent's response is invoked asynchronously using `arun`.
104-
- The method clears any leftover tokens in the queue before starting to stream the response.
105-
- Uses a buffer to ensure complete JSON objects are sent to prevent parsing errors.
154+
This method uses an efficient NDJSON streaming protocol for reliable parsing.
155+
It supports both memory and non-memory modes, adapting the execution method accordingly.
106156
"""
107-
# Clear any leftover tokens.
157+
if not self.streaming or not self.streaming_handler:
158+
raise ValueError("Streaming is not enabled")
159+
160+
# Clear any leftover tokens
108161
while not self.streaming_handler.queue.empty():
109162
self.streaming_handler.queue.get_nowait()
110163

111-
# If memory is enabled, use the synchronous `invoke` wrapped in run_in_executor.
164+
# Start the agent task based on memory configuration
112165
if self.memory:
113166
loop = asyncio.get_running_loop()
114167
task = loop.run_in_executor(
@@ -123,72 +176,78 @@ async def stream(self, message: str):
123176
self.agent_executor.arun({"input": message})
124177
)
125178

126-
# Use a smaller timeout to ensure more responsive streaming
127-
timeout = 0.01
179+
buffer = "" # Initialize an empty buffer for accumulating incomplete JSON
128180

129-
# Yield tokens as they become available.
181+
# Continue processing while the task is running or queue has items
130182
while not task.done() or not self.streaming_handler.queue.empty():
131183
try:
132-
# Get token with a short timeout to maintain streaming responsiveness
133-
token = await asyncio.wait_for(self.streaming_handler.queue.get(), timeout=timeout)
134-
135-
# Ensure token is a complete JSON object
136-
if token.endswith("\n"):
137-
# Token is already a complete JSON object, yield it directly
138-
yield token
139-
else:
140-
# Token might be incomplete, wait a tiny bit for more data
141-
buffer = token
142-
try:
143-
# Try to get more data with a very short timeout
144-
while not buffer.endswith("\n"):
145-
more_token = await asyncio.wait_for(
146-
self.streaming_handler.queue.get(),
147-
timeout=0.005
148-
)
149-
buffer += more_token
150-
# If we now have a complete line, break
151-
if "\n" in buffer:
152-
break
153-
except asyncio.TimeoutError:
154-
# If we timeout waiting for more data, that's okay
155-
# We'll just yield what we have if it's complete
156-
pass
157-
158-
# Process the buffer to yield complete JSON objects
159-
while "\n" in buffer:
160-
json_str, remaining = buffer.split("\n", 1)
161-
if json_str: # Only yield non-empty strings
162-
yield json_str + "\n"
163-
buffer = remaining
164-
165-
# If there's anything left in the buffer, keep it for next iteration
166-
if buffer:
167-
# Put it back in the queue for the next iteration
168-
self.streaming_handler.queue.put_nowait(buffer)
169-
except asyncio.TimeoutError:
170-
# Short timeout to keep the loop responsive
171-
await asyncio.sleep(0.01)
172-
continue
184+
# Try to get a token with a timeout to maintain responsiveness
185+
token = await asyncio.wait_for(
186+
self.streaming_handler.queue.get(),
187+
timeout=self.token_timeout
188+
)
173189

174-
result = await task
190+
# Add the new token to our buffer
191+
buffer += token
175192

176-
def ask(self, message: str):
193+
# If we have complete JSON objects (ending with newline), process them
194+
if "\n" in buffer:
195+
complete_jsons, buffer = await self._process_complete_json(buffer)
196+
for json_str in complete_jsons:
197+
yield json_str
198+
199+
except asyncio.TimeoutError:
200+
# No new tokens available, wait a bit before checking again
201+
await asyncio.sleep(self.poll_interval)
202+
continue
203+
except Exception as e:
204+
# Handle any parsing or processing errors
205+
error_event = {
206+
"type": "error",
207+
"data": f"Streaming error: {str(e)}"
208+
}
209+
yield json.dumps(error_event) + "\n"
210+
# Continue processing despite errors
211+
212+
# If there's anything left in the buffer after task completion, process it
213+
if buffer:
214+
try:
215+
# Try to parse it as JSON and yield if valid
216+
json.loads(buffer) # This is just a validation check
217+
yield buffer if buffer.endswith("\n") else buffer + "\n"
218+
except json.JSONDecodeError:
219+
# If it's not valid JSON, wrap it in an error event
220+
error_event = {
221+
"type": "error",
222+
"data": f"Invalid JSON in final buffer: {buffer}"
223+
}
224+
yield json.dumps(error_event) + "\n"
225+
226+
# Wait for the task to complete and get the result
227+
try:
228+
await task
229+
except Exception as e:
230+
# Handle any errors during task execution
231+
error_event = {
232+
"type": "error",
233+
"data": f"Task execution error: {str(e)}"
234+
}
235+
yield json.dumps(error_event) + "\n"
236+
237+
def ask(self, message: str) -> Dict[str, Any]:
    """
    Sends a message to the agent and returns the response.

    Args:
        message (str): The message to send to the agent.

    Returns:
        Dict[str, Any]: The response from the agent.
    """
    payload = {"input": message}

    # Without memory, a plain executor invocation is sufficient.
    if not self.memory:
        return self.agent_executor.invoke(payload)

    # Memory-backed agents need a session id in the run config.
    # NOTE(review): the session id is a fixed placeholder character —
    # presumably all calls share one history; confirm this is intended.
    session_config = {"configurable": {"session_id": "ـ"}}
    return self.agent_with_chat_history.invoke(
        input=payload,
        config=session_config,
    )

0 commit comments

Comments
 (0)