Skip to content

Commit e9a18bb

Browse files
committed
feat: improve agent example
1 parent aa3f974 commit e9a18bb

4 files changed

Lines changed: 200 additions & 75 deletions

File tree

examples/agent.py

Lines changed: 189 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,31 @@
22

33
import argparse
44
import asyncio
5-
from collections.abc import AsyncGenerator
5+
import os
6+
import random
7+
import readline
68
from pathlib import Path
7-
from typing import TYPE_CHECKING, cast
9+
from typing import TYPE_CHECKING, Literal, cast
810
from typing_extensions import override
911

10-
from openai import AsyncOpenAI
12+
from openai import AsyncOpenAI, pydantic_function_tool
1113
from openai.lib.streaming.chat import ChatCompletionStreamState
1214
from openai.types.chat import (
13-
ChatCompletionChunk,
15+
ChatCompletion,
1416
ChatCompletionMessageParam,
17+
ChatCompletionMessageToolCallUnion,
1518
)
16-
from pydantic import TypeAdapter
19+
from pydantic import BaseModel, Field, TypeAdapter
20+
from rich.console import Console
1721

1822
import duron
19-
from duron import RunOptions, op
23+
from duron import RunOptions
2024
from duron.codec import Codec
2125
from duron.contrib.storage import FileLogStorage
2226

2327
if TYPE_CHECKING:
2428
from typing import Any
2529

26-
from openai.types.chat import (
27-
ParsedChatCompletionMessage,
28-
)
29-
3030
from duron.codec import JSONValue
3131
from duron.typing import TypeHint
3232

@@ -52,22 +52,90 @@ def decode_json(self, encoded: JSONValue, expected_type: TypeHint[Any]) -> objec
5252
return cast("object", TypeAdapter(expected_type).validate_python(encoded))
5353

5454

55+
@duron.op
56+
async def do_input() -> str: # noqa: RUF029
57+
try:
58+
return input("> ") # noqa: ASYNC250
59+
except EOFError:
60+
os._exit(0)
61+
except KeyboardInterrupt:
62+
os._exit(1)
63+
64+
5565
@duron.fn(codec=PydanticCodec())
5666
async def agent_fn(ctx: duron.Context) -> None:
57-
result = await completion(
58-
ctx,
59-
messages=[
60-
{
61-
"role": "system",
62-
"content": "You are a helpful assistant!",
63-
},
64-
{
65-
"role": "user",
66-
"content": "Say hello to Duron.",
67-
},
68-
],
69-
)
70-
print(result.content)
67+
console = Console()
68+
history: list[ChatCompletionMessageParam] = [
69+
{
70+
"role": "system",
71+
"content": "You are a helpful assistant!",
72+
},
73+
]
74+
while True:
75+
msg = await ctx.run(do_input)
76+
history.append({
77+
"role": "user",
78+
"content": msg,
79+
})
80+
console.print("[bold cyan] USER[/bold cyan]", msg)
81+
while True:
82+
result = await completion(
83+
ctx,
84+
messages=history,
85+
)
86+
if result.choices[0].message.content:
87+
console.print(
88+
"[bold red]ASSISTANT[/bold red] ", result.choices[0].message.content
89+
)
90+
history.append({
91+
"role": "assistant",
92+
"content": result.choices[0].message.content,
93+
"tool_calls": [
94+
{
95+
"id": toolcall.id,
96+
"type": "function",
97+
"function": {
98+
"name": toolcall.function.name,
99+
"arguments": toolcall.function.arguments,
100+
},
101+
}
102+
for toolcall in result.choices[0].message.tool_calls or []
103+
if toolcall.type == "function"
104+
],
105+
})
106+
if not result.choices[0].message.tool_calls:
107+
break
108+
109+
tasks: list[asyncio.Task[tuple[str, str]]] = []
110+
for tool_call in result.choices[0].message.tool_calls:
111+
console.print("[bold yellow] CALL[/bold yellow]", tool_call.id)
112+
console.print(tool_call.model_dump_json())
113+
tasks.append(asyncio.create_task(ctx.run(call_tool, None, tool_call)))
114+
for id_, tool_result in await asyncio.gather(*tasks):
115+
console.print("[bold cyan] TOOL[/bold cyan]", id_)
116+
console.print(tool_result)
117+
history.append({
118+
"role": "tool",
119+
"tool_call_id": id_,
120+
"content": tool_result,
121+
})
122+
123+
124+
@duron.op
125+
async def call_tool(params: ChatCompletionMessageToolCallUnion) -> tuple[str, str]: # noqa: RUF029
126+
if params.type != "function" or not params.function.name:
127+
return params.id, '{"status": "error", "message": "Invalid tool call"}'
128+
tool_name = params.function.name
129+
130+
if tool_name == "get_temperature":
131+
return params.id, get_temperature(
132+
TemperatureInput.model_validate_json(params.function.arguments or "{}")
133+
).model_dump_json()
134+
if tool_name == "get_forecast":
135+
return params.id, get_forecast(
136+
ForecastInput.model_validate_json(params.function.arguments or "{}")
137+
).model_dump_json()
138+
return params.id, '{"status": "error", "message": "Unknown tool"}'
71139

72140

73141
async def main() -> None:
@@ -80,7 +148,7 @@ async def main() -> None:
80148
)
81149
args = parser.parse_args()
82150

83-
log_storage = FileLogStorage(Path("logs") / f"{args.session_id}.json")
151+
log_storage = FileLogStorage(Path("logs") / f"{args.session_id}.jsonl")
84152
async with agent_fn.invoke(log_storage) as job:
85153
await job.start()
86154
await job.wait()
@@ -89,62 +157,116 @@ async def main() -> None:
89157
async def completion(
90158
ctx: duron.Context,
91159
messages: list[ChatCompletionMessageParam],
92-
) -> ParsedChatCompletionMessage[None]:
93-
@op(
94-
checkpoint=True,
95-
action_type=ChatCompletionChunk,
96-
return_type=ChatCompletionStreamState | None,
97-
initial=lambda: None,
98-
reducer=lambda a, b: (
99-
s := a or ChatCompletionStreamState(),
100-
s.handle_chunk(b),
101-
s,
102-
)[-1],
103-
)
104-
async def _completion_stream(
105-
prev: ChatCompletionStreamState | None,
160+
) -> ChatCompletion:
161+
@duron.op
162+
async def _completion(
106163
messages: list[ChatCompletionMessageParam],
107-
) -> AsyncGenerator[ChatCompletionChunk, ChatCompletionStreamState | None]:
108-
if prev:
109-
msg = prev.current_completion_snapshot.choices[0].message
110-
messages = [
111-
*messages,
112-
{
113-
"role": "assistant",
114-
"content": msg.content,
115-
"tool_calls": (
116-
{
117-
"id": call.id,
118-
"type": call.type,
119-
"function": {
120-
"name": call.function.name,
121-
"arguments": call.function.arguments,
122-
},
123-
}
124-
for call in msg.tool_calls
125-
)
126-
if msg.tool_calls
127-
else (),
128-
},
129-
]
164+
) -> ChatCompletion:
165+
state = ChatCompletionStreamState()
130166
async for chunk in await client.chat.completions.create(
131167
messages=messages,
168+
tools=[
169+
pydantic_function_tool(
170+
TemperatureInput,
171+
name="get_temperature",
172+
description="Get current temperature for a location",
173+
),
174+
pydantic_function_tool(
175+
ForecastInput,
176+
name="get_forecast",
177+
description="Get weather forecast for a location",
178+
),
179+
],
132180
model=DEFAULT_MODEL,
133181
stream=True,
134182
):
135183
if chunk.object: # type: ignore[redundant-expr]
136-
yield chunk
184+
_ = state.handle_chunk(chunk)
185+
return state.get_final_completion()
137186

138-
state = await ctx.run(
139-
_completion_stream,
187+
return await ctx.run(
188+
_completion,
140189
RunOptions(
141190
metadata={"type": "chat.completions.create"},
142191
),
143192
messages,
144193
)
145-
assert state # noqa: S101
146-
return state.get_final_completion().choices[0].message
194+
195+
196+
# tools
197+
198+
199+
class TemperatureInput(BaseModel):
200+
location: str = Field(..., description="Location to get weather for")
201+
unit: Literal["celsius", "fahrenheit"] = Field(
202+
default="celsius", description="Temperature unit"
203+
)
204+
205+
206+
class TemperatureOutput(BaseModel):
207+
location: str
208+
temperature: float | None
209+
unit: Literal["celsius", "fahrenheit"]
210+
status: Literal["success", "error"]
211+
message: str | None = None
212+
213+
214+
class ForecastInput(BaseModel):
215+
location: str = Field(..., description="Location for forecast")
216+
days: int = Field(default=3, ge=1, le=7, description="Number of days (1-7)")
217+
218+
219+
class ForecastDay(BaseModel):
220+
day: int
221+
high: float
222+
low: float
223+
humidity: int
224+
wind_speed: int
225+
226+
227+
class ForecastOutput(BaseModel):
228+
location: str
229+
forecast: list[ForecastDay]
230+
status: Literal["success", "error"]
231+
232+
233+
# Simplified tool implementations
234+
def get_temperature(input_data: TemperatureInput) -> TemperatureOutput:
235+
# Generate random temperature based on realistic ranges
236+
if input_data.unit == "celsius":
237+
temp = round(random.uniform(0, 37), 1)
238+
else: # fahrenheit
239+
temp = round(random.uniform(-4, 113), 1)
240+
241+
return TemperatureOutput(
242+
location=input_data.location,
243+
temperature=temp,
244+
unit=input_data.unit,
245+
status="success",
246+
)
247+
248+
249+
def get_forecast(input_data: ForecastInput) -> ForecastOutput:
250+
forecast: list[ForecastDay] = []
251+
for i in range(input_data.days):
252+
high = round(random.uniform(15, 35), 1)
253+
low = round(random.uniform(5, high - 5), 1) # Low is always less than high
254+
255+
forecast.append(
256+
ForecastDay(
257+
day=i + 1,
258+
high=high,
259+
low=low,
260+
humidity=random.randint(30, 90),
261+
wind_speed=random.randint(5, 25),
262+
)
263+
)
264+
265+
return ForecastOutput(
266+
location=input_data.location, forecast=forecast, status="success"
267+
)
147268

148269

149270
if __name__ == "__main__":
271+
_ = readline
150272
asyncio.run(main())

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ docs = [
3636
examples = [
3737
"pydantic>=2.11.9",
3838
"openai>=2.2.0",
39+
"rich>=14.1.0",
3940
]
4041

4142
[tool.ruff.lint]
@@ -51,7 +52,7 @@ future-annotations = true
5152

5253
[tool.ruff.lint.per-file-ignores]
5354
"tests/**/*.py" = ["S101", "S311"]
54-
"examples/**/*.py" = ["T201"]
55+
"examples/**/*.py" = ["T201", "S311"]
5556

5657
[tool.ruff.lint.isort]
5758
extra-standard-library = ["typing_extensions"]

src/duron/_core/context.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,13 @@ async def run(
108108
await stream.discard()
109109
return await stream
110110

111-
return_type = (
112-
fn.return_type
113-
if isinstance(fn, Op) and fn.return_type
114-
else self._task.codec.inspect_function(
115-
cast("Callable[..., object]", fn),
116-
).return_type
117-
)
111+
if isinstance(fn, Op):
112+
if fn.return_type:
113+
return_type = fn.return_type
114+
else:
115+
return_type = self._task.codec.inspect_function(fn.fn).return_type
116+
else:
117+
return_type = self._task.codec.inspect_function(fn).return_type
118118

119119
metadata = options.metadata if options else None
120120
op = create_op(

uv.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)