22
33import argparse
44import asyncio
5- from collections .abc import AsyncGenerator
5+ import os
6+ import random
7+ import readline
68from pathlib import Path
7- from typing import TYPE_CHECKING , cast
9+ from typing import TYPE_CHECKING , Literal , cast
810from typing_extensions import override
911
10- from openai import AsyncOpenAI
12+ from openai import AsyncOpenAI , pydantic_function_tool
1113from openai .lib .streaming .chat import ChatCompletionStreamState
1214from openai .types .chat import (
13- ChatCompletionChunk ,
15+ ChatCompletion ,
1416 ChatCompletionMessageParam ,
17+ ChatCompletionMessageToolCallUnion ,
1518)
16- from pydantic import TypeAdapter
19+ from pydantic import BaseModel , Field , TypeAdapter
20+ from rich .console import Console
1721
1822import duron
19- from duron import RunOptions , op
23+ from duron import RunOptions
2024from duron .codec import Codec
2125from duron .contrib .storage import FileLogStorage
2226
2327if TYPE_CHECKING :
2428 from typing import Any
2529
26- from openai .types .chat import (
27- ParsedChatCompletionMessage ,
28- )
29-
3030 from duron .codec import JSONValue
3131 from duron .typing import TypeHint
3232
@@ -52,22 +52,90 @@ def decode_json(self, encoded: JSONValue, expected_type: TypeHint[Any]) -> objec
5252 return cast ("object" , TypeAdapter (expected_type ).validate_python (encoded ))
5353
5454
55+ @duron .op
56+ async def do_input () -> str : # noqa: RUF029
57+ try :
58+ return input ("> " ) # noqa: ASYNC250
59+ except EOFError :
60+ os ._exit (0 )
61+ except KeyboardInterrupt :
62+ os ._exit (1 )
63+
64+
5565@duron .fn (codec = PydanticCodec ())
5666async def agent_fn (ctx : duron .Context ) -> None :
57- result = await completion (
58- ctx ,
59- messages = [
60- {
61- "role" : "system" ,
62- "content" : "You are a helpful assistant!" ,
63- },
64- {
65- "role" : "user" ,
66- "content" : "Say hello to Duron." ,
67- },
68- ],
69- )
70- print (result .content )
67+ console = Console ()
68+ history : list [ChatCompletionMessageParam ] = [
69+ {
70+ "role" : "system" ,
71+ "content" : "You are a helpful assistant!" ,
72+ },
73+ ]
74+ while True :
75+ msg = await ctx .run (do_input )
76+ history .append ({
77+ "role" : "user" ,
78+ "content" : msg ,
79+ })
80+ console .print ("[bold cyan] USER[/bold cyan]" , msg )
81+ while True :
82+ result = await completion (
83+ ctx ,
84+ messages = history ,
85+ )
86+ if result .choices [0 ].message .content :
87+ console .print (
88+ "[bold red]ASSISTANT[/bold red] " , result .choices [0 ].message .content
89+ )
90+ history .append ({
91+ "role" : "assistant" ,
92+ "content" : result .choices [0 ].message .content ,
93+ "tool_calls" : [
94+ {
95+ "id" : toolcall .id ,
96+ "type" : "function" ,
97+ "function" : {
98+ "name" : toolcall .function .name ,
99+ "arguments" : toolcall .function .arguments ,
100+ },
101+ }
102+ for toolcall in result .choices [0 ].message .tool_calls or []
103+ if toolcall .type == "function"
104+ ],
105+ })
106+ if not result .choices [0 ].message .tool_calls :
107+ break
108+
109+ tasks : list [asyncio .Task [tuple [str , str ]]] = []
110+ for tool_call in result .choices [0 ].message .tool_calls :
111+ console .print ("[bold yellow] CALL[/bold yellow]" , tool_call .id )
112+ console .print (tool_call .model_dump_json ())
113+ tasks .append (asyncio .create_task (ctx .run (call_tool , None , tool_call )))
114+ for id_ , tool_result in await asyncio .gather (* tasks ):
115+ console .print ("[bold cyan] TOOL[/bold cyan]" , id_ )
116+ console .print (tool_result )
117+ history .append ({
118+ "role" : "tool" ,
119+ "tool_call_id" : id_ ,
120+ "content" : tool_result ,
121+ })
122+
123+
124+ @duron .op
125+ async def call_tool (params : ChatCompletionMessageToolCallUnion ) -> tuple [str , str ]: # noqa: RUF029
126+ if params .type != "function" or not params .function .name :
127+ return params .id , '{"status": "error", "message": "Invalid tool call"}'
128+ tool_name = params .function .name
129+
130+ if tool_name == "get_temperature" :
131+ return params .id , get_temperature (
132+ TemperatureInput .model_validate_json (params .function .arguments or "{}" )
133+ ).model_dump_json ()
134+ if tool_name == "get_forecast" :
135+ return params .id , get_forecast (
136+ ForecastInput .model_validate_json (params .function .arguments or "{}" )
137+ ).model_dump_json ()
138+ return params .id , '{"status": "error", "message": "Unknown tool"}'
71139
72140
73141async def main () -> None :
@@ -80,7 +148,7 @@ async def main() -> None:
80148 )
81149 args = parser .parse_args ()
82150
83- log_storage = FileLogStorage (Path ("logs" ) / f"{ args .session_id } .json " )
151+ log_storage = FileLogStorage (Path ("logs" ) / f"{ args .session_id } .jsonl " )
84152 async with agent_fn .invoke (log_storage ) as job :
85153 await job .start ()
86154 await job .wait ()
@@ -89,62 +157,116 @@ async def main() -> None:
89157async def completion (
90158 ctx : duron .Context ,
91159 messages : list [ChatCompletionMessageParam ],
92- ) -> ParsedChatCompletionMessage [None ]:
93- @op (
94- checkpoint = True ,
95- action_type = ChatCompletionChunk ,
96- return_type = ChatCompletionStreamState | None ,
97- initial = lambda : None ,
98- reducer = lambda a , b : (
99- s := a or ChatCompletionStreamState (),
100- s .handle_chunk (b ),
101- s ,
102- )[- 1 ],
103- )
104- async def _completion_stream (
105- prev : ChatCompletionStreamState | None ,
160+ ) -> ChatCompletion :
161+ @duron .op
162+ async def _completion (
106163 messages : list [ChatCompletionMessageParam ],
107- ) -> AsyncGenerator [ChatCompletionChunk , ChatCompletionStreamState | None ]:
108- if prev :
109- msg = prev .current_completion_snapshot .choices [0 ].message
110- messages = [
111- * messages ,
112- {
113- "role" : "assistant" ,
114- "content" : msg .content ,
115- "tool_calls" : (
116- {
117- "id" : call .id ,
118- "type" : call .type ,
119- "function" : {
120- "name" : call .function .name ,
121- "arguments" : call .function .arguments ,
122- },
123- }
124- for call in msg .tool_calls
125- )
126- if msg .tool_calls
127- else (),
128- },
129- ]
164+ ) -> ChatCompletion :
165+ state = ChatCompletionStreamState ()
130166 async for chunk in await client .chat .completions .create (
131167 messages = messages ,
168+ tools = [
169+ pydantic_function_tool (
170+ TemperatureInput ,
171+ name = "get_temperature" ,
172+ description = "Get current temperature for a location" ,
173+ ),
174+ pydantic_function_tool (
175+ ForecastInput ,
176+ name = "get_forecast" ,
177+ description = "Get weather forecast for a location" ,
178+ ),
179+ ],
132180 model = DEFAULT_MODEL ,
133181 stream = True ,
134182 ):
135183 if chunk .object : # type: ignore[redundant-expr]
136- yield chunk
184+ _ = state .handle_chunk (chunk )
185+ return state .get_final_completion ()
137186
138- state = await ctx .run (
139- _completion_stream ,
187+ return await ctx .run (
188+ _completion ,
140189 RunOptions (
141190 metadata = {"type" : "chat.completions.create" },
142191 ),
143192 messages ,
144193 )
145- assert state # noqa: S101
146- return state .get_final_completion ().choices [0 ].message
194+
195+
196+ # tools
197+
198+
199+ class TemperatureInput (BaseModel ):
200+ location : str = Field (..., description = "Location to get weather for" )
201+ unit : Literal ["celsius" , "fahrenheit" ] = Field (
202+ default = "celsius" , description = "Temperature unit"
203+ )
204+
205+
206+ class TemperatureOutput (BaseModel ):
207+ location : str
208+ temperature : float | None
209+ unit : Literal ["celsius" , "fahrenheit" ]
210+ status : Literal ["success" , "error" ]
211+ message : str | None = None
212+
213+
214+ class ForecastInput (BaseModel ):
215+ location : str = Field (..., description = "Location for forecast" )
216+ days : int = Field (default = 3 , ge = 1 , le = 7 , description = "Number of days (1-7)" )
217+
218+
219+ class ForecastDay (BaseModel ):
220+ day : int
221+ high : float
222+ low : float
223+ humidity : int
224+ wind_speed : int
225+
226+
227+ class ForecastOutput (BaseModel ):
228+ location : str
229+ forecast : list [ForecastDay ]
230+ status : Literal ["success" , "error" ]
231+
232+
233+ # Simplified tool implementations
234+ def get_temperature (input_data : TemperatureInput ) -> TemperatureOutput :
235+ # Generate random temperature based on realistic ranges
236+ if input_data .unit == "celsius" :
237+ temp = round (random .uniform (0 , 37 ), 1 )
238+ else : # fahrenheit
239+ temp = round (random .uniform (- 4 , 113 ), 1 )
240+
241+ return TemperatureOutput (
242+ location = input_data .location ,
243+ temperature = temp ,
244+ unit = input_data .unit ,
245+ status = "success" ,
246+ )
247+
248+
249+ def get_forecast (input_data : ForecastInput ) -> ForecastOutput :
250+ forecast : list [ForecastDay ] = []
251+ for i in range (input_data .days ):
252+ high = round (random .uniform (15 , 35 ), 1 )
253+ low = round (random .uniform (5 , high - 5 ), 1 ) # Low is always less than high
254+
255+ forecast .append (
256+ ForecastDay (
257+ day = i + 1 ,
258+ high = high ,
259+ low = low ,
260+ humidity = random .randint (30 , 90 ),
261+ wind_speed = random .randint (5 , 25 ),
262+ )
263+ )
264+
265+ return ForecastOutput (
266+ location = input_data .location , forecast = forecast , status = "success"
267+ )
147268
148269
149270if __name__ == "__main__" :
271+ _ = readline
150272 asyncio .run (main ())
0 commit comments