# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.

-from typing import Any
+from typing import AsyncGenerator, Dict, Optional, Tuple
from quart import Blueprint, jsonify, request, Response, render_template, current_app

import asyncio
from azure.identity import DefaultAzureCredential

from azure.ai.projects.models import (
-    MessageDeltaTextContent,
    MessageDeltaChunk,
    ThreadMessage,
    FileSearchTool,
    AsyncToolSet,
    FilePurpose,
-    AgentStreamEvent
+    ThreadMessage,
+    StreamEventData,
+    AsyncAgentEventHandler,
+    Agent,
+    VectorStore
)

-bp = Blueprint("chat", __name__, template_folder="templates", static_folder="static")
+class ChatBlueprint(Blueprint):
+    ai_client: AIProjectClient
+    agent: Agent
+    files: Dict[str, str]
+    vector_store: VectorStore
+
+bp = ChatBlueprint("chat", __name__, template_folder="templates", static_folder="static")
+
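+# Each handler below returns a Server-Sent Events frame ("data: ...\n\n") or None; the
+# returned values are surfaced through the stream iterator in get_result() and forwarded to the client.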
+class MyEventHandler(AsyncAgentEventHandler[str]):
+
+    async def on_message_delta(
+        self, delta: "MessageDeltaChunk"
+    ) -> Optional[str]:
+        stream_data = json.dumps({'content': delta.text, 'type': "message"})
+        return f"data: {stream_data}\n\n"
+
+    async def on_thread_message(
+        self, message: "ThreadMessage"
+    ) -> Optional[str]:
+        if message.status == "completed":
+            annotations = [annotation.as_dict() for annotation in message.file_citation_annotations]
+            stream_data = json.dumps({'content': message.text_messages[0].text.value, 'annotations': annotations, 'type': "completed_message"})
+            return f"data: {stream_data}\n\n"
+        return None
+
+    async def on_error(self, data: str) -> Optional[str]:
+        print(f"An error occurred. Data: {data}")
+        stream_data = json.dumps({'type': "stream_end"})
+        return f"data: {stream_data}\n\n"
+
+    async def on_done(
+        self,
+    ) -> Optional[str]:
+        stream_data = json.dumps({'type': "stream_end"})
+        return f"data: {stream_data}\n\n"
+


@bp.before_app_serving
@@ -33,15 +71,15 @@ async def start_server():
    )

    # TODO: add more files are not supported for citation at the moment
-    files = ["product_info_1.md"]
-    file_ids = []
-    for file in files:
-        file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'files', file))
+    file_names = ["product_info_1.md", "product_info_2.md"]
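+    # Map of uploaded file ID -> local file path; /fetch-document uses it to serve cited files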
+    files: Dict[str, str] = {}
+    for file_name in file_names:
+        file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'files', file_name))
        print(f"Uploading file {file_path}")
        file = await ai_client.agents.upload_file_and_poll(file_path=file_path, purpose=FilePurpose.AGENTS)
-        file_ids.append(file.id)
+        files.update({file.id: file_path})

-    vector_store = await ai_client.agents.create_vector_store(file_ids=file_ids, name="sample_store")
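+    # create_vector_store_and_poll waits until the uploaded files have been ingested before continuing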
+    vector_store = await ai_client.agents.create_vector_store_and_poll(file_ids=list(files.keys()), name="sample_store")

    file_search_tool = FileSearchTool(vector_store_ids=[vector_store.id])

@@ -59,12 +97,12 @@ async def start_server():
    bp.ai_client = ai_client
    bp.agent = agent
    bp.vector_store = vector_store
-    bp.file_ids = file_ids
+    bp.files = files


@bp.after_app_serving
async def stop_server():
-    for file_id in bp.file_ids:
+    for file_id in bp.files.keys():
        await bp.ai_client.agents.delete_file(file_id)
        print(f"Deleted file {file_id}")

@@ -78,47 +116,32 @@ async def stop_server():
    await bp.ai_client.close()
    print("Closed AIProjectClient")

+
+
+
@bp.get("/")
async def index():
    return await render_template("index.html")

-async def create_stream(thread_id: str, agent_id: str):
+
+
+async def get_result(thread_id: str, agent_id: str) -> AsyncGenerator[str, None]:
    async with await bp.ai_client.agents.create_stream(
-        thread_id=thread_id, assistant_id=agent_id
+        thread_id=thread_id, assistant_id=agent_id,
+        event_handler=MyEventHandler()
    ) as stream:
-        accumulated_text = ""
-
-        async for event_type, event_data in stream:
-
-            stream_data = None
-            if isinstance(event_data, MessageDeltaChunk):
-                for content_part in event_data.delta.content:
-                    if isinstance(content_part, MessageDeltaTextContent):
-                        text_value = content_part.text.value if content_part.text else "No text"
-                        accumulated_text += text_value
-                        print(f"Text delta received: {text_value}")
-                        stream_data = json.dumps({'content': text_value, 'type': "message"})
-
-            elif isinstance(event_data, ThreadMessage):
-                print(f"ThreadMessage created. ID: {event_data.id}, Status: {event_data.status}")
-                if (event_data.status == "completed"):
-                    stream_data = json.dumps({'content': accumulated_text, 'type': "completed_message"})
-
-            elif event_type == AgentStreamEvent.DONE:
-                print("Stream completed.")
-                stream_data = json.dumps({'type': "stream_end"})
-
-            if stream_data:
-                yield f"data: {stream_data}\n\n"
+        # Iterate over the stream to trigger the event handler functions
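+        # Each iteration yields (event_type, event_data, handler return value); only the handler output is streamed to the client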
+        async for _, _, event_func_return_val in stream:
+            if event_func_return_val:
+                yield event_func_return_val

-
@bp.route('/chat', methods=['POST'])
async def chat():
    thread_id = request.cookies.get('thread_id')
    agent_id = request.cookies.get('agent_id')
    thread = None

-    if thread_id or agent_id != bp.agent.id:
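+    # Reuse the existing thread only if it was created for the currently configured agent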
+    if thread_id and agent_id == bp.agent.id:
        # Check if the thread is still active
        try:
            thread = await bp.ai_client.agents.get_thread(thread_id)
@@ -147,24 +170,21 @@ async def chat():
        'Content-Type': 'text/event-stream'
    }

-    response = Response(create_stream(thread_id, agent_id), headers=headers)
+    response = Response(get_result(thread_id, agent_id), headers=headers)
    response.set_cookie('thread_id', thread_id)
    response.set_cookie('agent_id', agent_id)
    return response

@bp.route('/fetch-document', methods=['GET'])
async def fetch_document():
-    filename = "product_info_1.md"
-
-    # Get the file path from the mapping
-    file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'files', filename))
-
-    if not os.path.exists(file_path):
-        return jsonify({"error": f"File not found: {filename}"}), 404
+    file_id = request.args.get('file_id')
+    current_app.logger.info(f"Fetching document: {file_id}")
+    if not file_id:
+        return jsonify({"error": "file_id is required"}), 400

    try:
        # Read the file content asynchronously using asyncio.to_thread
-        data = await asyncio.to_thread(read_file, file_path)
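+        # bp.files maps uploaded file IDs to their local paths (populated in start_server)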
+        data = await asyncio.to_thread(read_file, bp.files[file_id])
        return Response(data, content_type='text/plain')

    except Exception as e: