Skip to content

Commit 2da4416

Browse files
authored
Merge pull request #1428 from 3clyp50/dirtyjson
Dispatch tool calls at first completed JSON object
2 parents fbf6a8d + 5a22235 commit 2da4416

6 files changed

Lines changed: 266 additions & 40 deletions

File tree

agent.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ async def monologue(self):
388388
self.context.streaming_agent = self # mark self as current streamer
389389
self.loop_data.iteration += 1
390390
self.loop_data.params_temporary = {} # clear temporary params
391+
last_response_stream_full = ""
391392

392393
# call message_loop_start extensions
393394
await extension.call_extensions_async(
@@ -425,12 +426,32 @@ async def reasoning_callback(chunk: str, full: str):
425426
await self.handle_reasoning_stream(stream_data["full"])
426427

427428
async def stream_callback(chunk: str, full: str):
429+
nonlocal last_response_stream_full
428430
await self.handle_intervention()
429431
# output the agent response stream
430432
if chunk == full:
431433
printer.print("Response: ") # start of response
432434
# Pass chunk and full data to extensions for processing
433435
stream_data = {"chunk": chunk, "full": full}
436+
stop_response: str | None = None
437+
438+
snapshot = extract_tools.extract_json_root_string(full)
439+
if snapshot:
440+
parsed_snapshot = extract_tools.json_parse_dirty(snapshot)
441+
if parsed_snapshot is not None:
442+
try:
443+
await self.validate_tool_request(parsed_snapshot)
444+
except Exception:
445+
pass
446+
else:
447+
previous_full = last_response_stream_full
448+
stream_data["full"] = snapshot
449+
if snapshot.startswith(previous_full):
450+
stream_data["chunk"] = snapshot[len(previous_full) :]
451+
else:
452+
stream_data["chunk"] = snapshot
453+
stop_response = snapshot
454+
434455
await extension.call_extensions_async(
435456
"response_stream_chunk",
436457
self,
@@ -442,6 +463,9 @@ async def stream_callback(chunk: str, full: str):
442463
printer.stream(stream_data["chunk"])
443464
# Use the potentially modified full text for downstream processing
444465
await self.handle_response_stream(stream_data["full"])
466+
last_response_stream_full = stream_data["full"]
467+
if stop_response is not None:
468+
return stop_response
445469

446470
# call main LLM
447471
agent_response, _reasoning = await self.call_chat_model(
@@ -770,7 +794,7 @@ async def stream_callback(chunk: str, total: str):
770794
async def call_chat_model(
771795
self,
772796
messages: list[BaseMessage],
773-
response_callback: Callable[[str, str], Awaitable[None]] | None = None,
797+
response_callback: Callable[[str, str], Awaitable[str | None]] | None = None,
774798
reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None,
775799
background: bool = False,
776800
explicit_caching: bool = True,

helpers/dirty_json.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ def _reset(self):
2525
self.current_char = None
2626
self.result = None
2727
self.stack = []
28+
self.completed = False
29+
self._parsing_started = False
30+
31+
def _pop_stack(self, root_closed: bool = False):
32+
"""Pop from the parsing stack and mark completed only on an explicit root close."""
33+
self.stack.pop()
34+
if root_closed and self._parsing_started and not self.stack:
35+
self.completed = True
2836

2937
@staticmethod
3038
def parse_string(json_string):
@@ -95,13 +103,17 @@ def _skip_multi_line_comment(self):
95103
self._advance()
96104

97105
def _parse(self):
106+
if self.completed and not self.stack:
107+
return
98108
if self.result is None:
99109
self.result = self._parse_value()
100110
else:
101111
self._continue_parsing()
102112

103113
def _continue_parsing(self):
104114
while self.current_char is not None:
115+
if self.completed and not self.stack:
116+
return
105117
if isinstance(self.result, dict):
106118
self._parse_object_content()
107119
elif isinstance(self.result, list):
@@ -114,7 +126,9 @@ def _continue_parsing(self):
114126
def _parse_value(self):
115127
self._skip_whitespace()
116128
if self.current_char == "{":
117-
if self._peek(1) == "{": # Handle {{
129+
# Only treat doubled braces as a wrapper at the root; nested objects
130+
# must keep their closing braces paired correctly.
131+
if not self.stack and self._peek(1) == "{": # Handle {{
118132
self._advance(2)
119133
return self._parse_object()
120134
elif self.current_char == "[":
@@ -153,21 +167,24 @@ def _parse_object(self):
153167
obj = {}
154168
self._advance() # Skip opening brace
155169
self.stack.append(obj)
170+
self._parsing_started = True
156171
self._parse_object_content()
157172
return obj
158173

159174
def _parse_object_content(self):
160175
while self.current_char is not None:
161176
self._skip_whitespace()
162177
if self.current_char == "}":
163-
if self._peek(1) == "}": # Handle }}
178+
# Root-level wrapper outputs may end in "}}"; nested objects must
179+
# still close one brace at a time.
180+
if len(self.stack) == 1 and self._peek(1) == "}": # Handle }}
164181
self._advance(2)
165182
else:
166183
self._advance()
167-
self.stack.pop()
184+
self._pop_stack(root_closed=True)
168185
return
169186
if self.current_char is None:
170-
self.stack.pop()
187+
self._pop_stack()
171188
return # End of input reached while parsing object
172189

173190
key = self._parse_key()
@@ -190,7 +207,7 @@ def _parse_object_content(self):
190207
continue
191208
elif self.current_char != "}":
192209
if self.current_char is None:
193-
self.stack.pop()
210+
self._pop_stack()
194211
return # End of input reached after value
195212
continue
196213

@@ -216,6 +233,7 @@ def _parse_array(self):
216233
arr = []
217234
self._advance() # Skip opening bracket
218235
self.stack.append(arr)
236+
self._parsing_started = True
219237
self._parse_array_content()
220238
return arr
221239

@@ -224,7 +242,7 @@ def _parse_array_content(self):
224242
self._skip_whitespace()
225243
if self.current_char == "]":
226244
self._advance()
227-
self.stack.pop()
245+
self._pop_stack(root_closed=True)
228246
return
229247
value = self._parse_value()
230248
self.stack[-1].append(value)
@@ -236,10 +254,10 @@ def _parse_array_content(self):
236254
if self.current_char is None or self.current_char == "]":
237255
if self.current_char == "]":
238256
self._advance()
239-
self.stack.pop()
257+
self._pop_stack(root_closed=True)
240258
return
241259
elif self.current_char != "]":
242-
self.stack.pop()
260+
self._pop_stack()
243261
return
244262

245263
def _parse_string(self):

helpers/extract_tools.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,28 @@ def json_parse_dirty(json: str) -> dict[str, Any] | None:
1919
return None
2020
return None
2121

22+
def extract_json_root_string(content: str) -> str | None:
23+
if not content or not isinstance(content, str):
24+
return None
25+
26+
start = content.find("{")
27+
if start == -1:
28+
return None
29+
first_array = content.find("[")
30+
if first_array != -1 and first_array < start:
31+
return None
32+
33+
parser = DirtyJson()
34+
try:
35+
parser.parse(content[start:])
36+
except Exception:
37+
return None
38+
39+
if not parser.completed:
40+
return None
41+
42+
return content[start : start + parser.index]
43+
2244

2345
def extract_json_object_string(content):
2446
start = content.find("{")

models.py

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ async def unified_call(
475475
system_message="",
476476
user_message="",
477477
messages: List[BaseMessage] | None = None,
478-
response_callback: Callable[[str, str], Awaitable[None]] | None = None,
478+
response_callback: Callable[[str, str], Awaitable[str | None]] | None = None,
479479
reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None,
480480
tokens_callback: Callable[[str, int], Awaitable[None]] | None = None,
481481
rate_limiter_callback: (
@@ -526,36 +526,46 @@ async def unified_call(
526526

527527
if stream:
528528
# iterate over chunks
529-
async for chunk in _completion: # type: ignore
530-
got_any_chunk = True
531-
# parse chunk
532-
parsed = _parse_chunk(chunk)
533-
output = result.add_chunk(parsed)
534-
535-
# collect reasoning delta and call callbacks
536-
if output["reasoning_delta"]:
537-
if reasoning_callback:
538-
await reasoning_callback(output["reasoning_delta"], result.reasoning)
539-
if tokens_callback:
540-
await tokens_callback(
541-
output["reasoning_delta"],
542-
approximate_tokens(output["reasoning_delta"]),
543-
)
544-
# Add output tokens to rate limiter if configured
545-
if limiter:
546-
limiter.add(output=approximate_tokens(output["reasoning_delta"]))
547-
# collect response delta and call callbacks
548-
if output["response_delta"]:
549-
if response_callback:
550-
await response_callback(output["response_delta"], result.response)
551-
if tokens_callback:
552-
await tokens_callback(
553-
output["response_delta"],
554-
approximate_tokens(output["response_delta"]),
555-
)
556-
# Add output tokens to rate limiter if configured
557-
if limiter:
558-
limiter.add(output=approximate_tokens(output["response_delta"]))
529+
stop_response: str | None = None
530+
try:
531+
async for chunk in _completion: # type: ignore
532+
got_any_chunk = True
533+
# parse chunk
534+
parsed = _parse_chunk(chunk)
535+
output = result.add_chunk(parsed)
536+
537+
# collect reasoning delta and call callbacks
538+
if output["reasoning_delta"]:
539+
if reasoning_callback:
540+
await reasoning_callback(output["reasoning_delta"], result.reasoning)
541+
if tokens_callback:
542+
await tokens_callback(
543+
output["reasoning_delta"],
544+
approximate_tokens(output["reasoning_delta"]),
545+
)
546+
# Add output tokens to rate limiter if configured
547+
if limiter:
548+
limiter.add(output=approximate_tokens(output["reasoning_delta"]))
549+
# collect response delta and call callbacks
550+
if output["response_delta"]:
551+
if response_callback:
552+
stop_response = await response_callback(
553+
output["response_delta"], result.response
554+
)
555+
if tokens_callback:
556+
await tokens_callback(
557+
output["response_delta"],
558+
approximate_tokens(output["response_delta"]),
559+
)
560+
# Add output tokens to rate limiter if configured
561+
if limiter:
562+
limiter.add(output=approximate_tokens(output["response_delta"]))
563+
if stop_response is not None:
564+
result.response = stop_response
565+
break
566+
finally:
567+
if stop_response is not None and hasattr(_completion, "aclose"):
568+
await _completion.aclose() # type: ignore[attr-defined]
559569

560570
# non-stream response
561571
else:

tests/test_dirty_json.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
from pathlib import Path
5+
6+
import pytest
7+
8+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
9+
if str(PROJECT_ROOT) not in sys.path:
10+
sys.path.insert(0, str(PROJECT_ROOT))
11+
12+
from helpers.dirty_json import DirtyJson
13+
14+
15+
@pytest.mark.parametrize(
16+
("payload", "expected"),
17+
[
18+
(
19+
'{"tool_name":"x","tool_args":{}}',
20+
{"tool_name": "x", "tool_args": {}},
21+
),
22+
("[1, 2, 3]", [1, 2, 3]),
23+
],
24+
)
25+
def test_completed_true_when_root_is_explicitly_closed(payload, expected) -> None:
26+
parser = DirtyJson()
27+
28+
assert parser.parse(payload) == expected
29+
assert parser.completed is True
30+
31+
32+
def test_completed_false_when_root_hits_eof_before_closing() -> None:
33+
parser = DirtyJson()
34+
35+
assert parser.parse('{"tool_name":"x","tool_args":{}') == {
36+
"tool_name": "x",
37+
"tool_args": {},
38+
}
39+
assert parser.completed is False
40+
41+
42+
def test_completed_remains_true_after_trailing_content() -> None:
43+
parser = DirtyJson()
44+
45+
assert parser.feed('{"tool_name":"x","tool_args":{}}') == {
46+
"tool_name": "x",
47+
"tool_args": {},
48+
}
49+
assert parser.completed is True
50+
51+
assert parser.feed(" trailing noise") == {
52+
"tool_name": "x",
53+
"tool_args": {},
54+
}
55+
56+
assert parser.completed is True

0 commit comments

Comments
 (0)