1010from fastapi .responses import HTMLResponse
1111from fastapi .templating import Jinja2Templates
1212
13- from featuremanagement .azuremonitor import track_event
14-
1513import uuid
1614import pathlib
1715from azure .ai .inference .prompts import PromptTemplate
3129)
3230
3331from opentelemetry .baggage import get_baggage
34- from azure .ai .evaluation import CoherenceEvaluator , FluencyEvaluator , RelevanceEvaluator , ViolenceEvaluator , SexualEvaluator , HateUnfairnessEvaluator , ProtectedMaterialEvaluator , ContentSafetyEvaluator
35- import asyncio
3632from opentelemetry .baggage import set_baggage , get_baggage
3733from opentelemetry .context import attach
3834from featuremanagement import TargetingContext , FeatureManager
39- from azure .identity import DefaultAzureCredential
35+ from azure .appconfiguration .provider import AzureAppConfigurationProvider
36+
37+ # import asyncio
38+ # from azure.ai.evaluation import CoherenceEvaluator, FluencyEvaluator, RelevanceEvaluator, ViolenceEvaluator, SexualEvaluator, HateUnfairnessEvaluator, ProtectedMaterialEvaluator, ContentSafetyEvaluator
39+ # from azure.identity import DefaultAzureCredential
4040
4141router = fastapi .APIRouter ()
4242templates = Jinja2Templates (directory = "api/templates" )
@@ -51,9 +51,12 @@ def get_chat_model(request: Request) -> str:
def get_search_index_namager(request: Request) -> SearchIndexManager:
    """Return the app-wide SearchIndexManager from FastAPI application state.

    NOTE(review): the name contains a typo ("namager"); kept as-is because
    route handlers reference it via Depends().
    """
    state = request.app.state
    return state.search_index_manager
5353
def get_feature_manager(request: Request) -> FeatureManager:
    """Return the app-wide FeatureManager from FastAPI application state."""
    state = request.app.state
    return state.feature_manager
5656
def get_app_config(request: Request) -> AzureAppConfigurationProvider:
    """Return the app-wide Azure App Configuration provider from app state."""
    state = request.app.state
    return state.app_config
5760class Message (pydantic .BaseModel ):
5861 content : str
5962 role : str = "user"
@@ -78,16 +81,20 @@ async def chat_stream_handler(
7881 chat_client : ChatCompletionsClient = Depends (get_chat_client ),
7982 model_deployment_name : str = Depends (get_chat_model ),
8083 search_index_manager : SearchIndexManager = Depends (get_search_index_namager ),
81- feature_manager : FeatureManager = Depends (get_feature_manager )
84+ feature_manager : FeatureManager = Depends (get_feature_manager ),
85+ app_config : AzureAppConfigurationProvider = Depends (get_app_config ),
8286) -> fastapi .responses .StreamingResponse :
8387 if chat_client is None :
8488 raise Exception ("Chat client not initialized" )
8589
8690 async def response_stream ():
8791 messages = [{"role" : message .role , "content" : message .content } for message in chat_request .messages ]
8892
89- targeting_id = chat_request .sessionState .get ('sessionId' , str (uuid .uuid4 ()))
90- attach (set_baggage ("Microsoft.TargetingId" , targeting_id ))
93+ # Refresh config and set targeting context for analysis
94+ if app_config and feature_manager :
95+ app_config .refresh ()
96+ targeting_id = chat_request .sessionState .get ('sessionId' , str (uuid .uuid4 ()))
97+ attach (set_baggage ("Microsoft.TargetingId" , targeting_id ))
9198
9299 # figure out which prompty template to use
93100 prompt_template = "prompt.v1.prompty"
@@ -172,54 +179,69 @@ async def response_stream():
172179 + "\n "
173180 )
174181
182+ # TODO: add variant to response
183+
175184 return fastapi .responses .StreamingResponse (response_stream ())
176185
177186
def get_targeting_context() -> TargetingContext:
    """Build a feature-flag TargetingContext from OpenTelemetry baggage.

    The chat handlers attach "Microsoft.TargetingId" to baggage per request;
    this reads it back so variant targeting can key off the session id.
    """
    user_id = get_baggage("Microsoft.TargetingId")
    return TargetingContext(user_id=user_id)
180189
181- # @router.post("/chat")
182- # async def chat_nostream_handler(
183- # chat_request: ChatRequest,
184- # request: Request
185- # ):
186- # chat_client = globals["chat"]
187- # if chat_client is None:
188- # raise Exception("Chat client not initialized")
@router.post("/chat")
async def chat_nostream_handler(
    chat_request: ChatRequest,
    chat_client: ChatCompletionsClient = Depends(get_chat_client),
    model_deployment_name: str = Depends(get_chat_model),
    search_index_manager: SearchIndexManager = Depends(get_search_index_namager),
    feature_manager: FeatureManager = Depends(get_feature_manager),
    app_config: AzureAppConfigurationProvider = Depends(get_app_config),
):
    """Non-streaming chat endpoint.

    Resolves the prompty template (request override > feature-flag variant >
    default), optionally prepends RAG context from the search index, calls the
    chat model once, and returns ``{"answer": ..., "variant": ...}``.
    """
    # Consistency with chat_stream_handler: fail fast if startup never set a client.
    if chat_client is None:
        raise Exception("Chat client not initialized")

    messages = [{"role": m.role, "content": m.content} for m in chat_request.messages]

    # Refresh config and set targeting context for analysis.
    targeting_id = chat_request.sessionState.get('sessionId', str(uuid.uuid4()))
    if app_config and feature_manager:
        app_config.refresh()
        attach(set_baggage("Microsoft.TargetingId", targeting_id))

    # Figure out which prompty template to use.
    # BUG FIX: previously `prompt_variant` was referenced in the return
    # statement but only assigned on the feature-manager path, so a request
    # with prompt_override (or a missing feature manager) raised NameError.
    # Track the chosen variant name explicitly instead.
    prompt_template = "prompt.v1.prompty"
    variant = None
    if chat_request.prompt_override:
        prompt_template = chat_request.prompt_override
        variant = chat_request.prompt_override
    elif feature_manager is not None:
        prompt_variant = feature_manager.get_variant("prompty_file")  # replace this with prompt_asset
        if prompt_variant and prompt_variant.configuration:  # TODO: check file exists
            prompt_template = prompt_variant.configuration
            variant = prompt_variant.name

    prompt = PromptTemplate.from_prompty(pathlib.Path(__file__).parent.resolve() / prompt_template)
    prompt_messages = prompt.create_messages()

    # Use RAG model, only if we were provided index and we have found a context there.
    if search_index_manager is not None:
        context = await search_index_manager.search(chat_request)
        if context:
            prompt_messages = PromptTemplate.from_string(
                'You are a helpful assistant that answers some questions '
                'with the help of some context data.\n\nHere is '
                'the context data:\n\n{{context}}').create_messages(data=dict(context=context))
            logger.info(f"{prompt_messages=}")
        else:
            logger.info("Unable to find the relevant information in the index for the request.")

    try:
        response = await chat_client.complete(
            model=model_deployment_name, messages=prompt_messages + messages, stream=False
        )
        answer = response.choices[0].message.content
    except Exception as e:
        # NOTE(review): the error dict is returned as the "answer" value,
        # preserving existing behavior; telemetry remains disabled.
        error = {"Error": str(e)}
        # track_event("ErrorLLM", targeting_id, error)
        answer = error

    # TODO: add variant to response telemetry
    return {"answer": answer, "variant": variant}
242+
222243
244+ # Inline Evaluation Prototype
223245
224246 # conversation = {}
225247
@@ -247,28 +269,28 @@ def get_targeting_context() -> TargetingContext:
247269
248270 # asyncio.create_task(run_evals(eval_input, targeting_id, project.scope, DefaultAzureCredential()))
249271
250- return { "answer" : answer , "variant" : variant }
272+ # return { "answer": answer, "variant": variant }
251273
252274
253- async def run_evals (eval_input , targeting_id , ai_project_scope , credential ):
254- run_eval (FluencyEvaluator , eval_input , targeting_id )
255- run_eval (RelevanceEvaluator , eval_input , targeting_id )
256- run_eval (CoherenceEvaluator , eval_input , targeting_id )
257-
258- run_safety_eval (ViolenceEvaluator , eval_input , targeting_id , ai_project_scope , credential )
259- run_safety_eval (SexualEvaluator , eval_input , targeting_id , ai_project_scope , credential )
260- run_safety_eval (HateUnfairnessEvaluator , eval_input , targeting_id , ai_project_scope , credential )
261- run_safety_eval (ProtectedMaterialEvaluator , eval_input , targeting_id , ai_project_scope , credential )
262- run_safety_eval (ContentSafetyEvaluator , eval_input , targeting_id , ai_project_scope , credential )
263-
264- def run_safety_eval (evaluator , eval_input , targeting_id , ai_project_scope , credential ):
265- eval = evaluator (credential = credential , azure_ai_project = ai_project_scope )
266- score = eval (** eval_input )
267- score .update ({"evaluator_id" : eval .id })
268- track_event ("gen.ai." + type (eval ).__name__ , targeting_id , score )
269-
270- def run_eval (evaluator , eval_input , targeting_id ):
271- eval = evaluator (globals ["model_config" ])
272- score = eval (** eval_input )
273- score .update ({"evaluator_id" : evaluator .id })
274- track_event ("gen.ai." + evaluator .__name__ , targeting_id , score )
275+ # async def run_evals(eval_input, targeting_id, ai_project_scope, credential):
276+ # run_eval(FluencyEvaluator, eval_input, targeting_id)
277+ # run_eval(RelevanceEvaluator, eval_input, targeting_id)
278+ # run_eval(CoherenceEvaluator, eval_input, targeting_id)
279+
280+ # run_safety_eval(ViolenceEvaluator, eval_input, targeting_id, ai_project_scope, credential)
281+ # run_safety_eval(SexualEvaluator, eval_input, targeting_id, ai_project_scope, credential)
282+ # run_safety_eval(HateUnfairnessEvaluator, eval_input, targeting_id, ai_project_scope, credential)
283+ # run_safety_eval(ProtectedMaterialEvaluator, eval_input, targeting_id, ai_project_scope, credential)
284+ # run_safety_eval(ContentSafetyEvaluator, eval_input, targeting_id, ai_project_scope, credential)
285+
286+ # def run_safety_eval(evaluator, eval_input, targeting_id, ai_project_scope, credential):
287+ # eval = evaluator(credential=credential, azure_ai_project=ai_project_scope)
288+ # score = eval(**eval_input)
289+ # score.update({"evaluator_id": eval.id})
290+ # track_event("gen.ai." + type(eval).__name__, targeting_id, score)
291+
292+ # def run_eval(evaluator, eval_input, targeting_id):
293+ # eval = evaluator(globals["model_config"])
294+ # score = eval(**eval_input)
295+ # score.update({"evaluator_id": evaluator.id})
296+ # track_event("gen.ai." + evaluator.__name__, targeting_id, score)
0 commit comments