Skip to content

Commit 9283f1f

Browse files
committed
cleanup and switch to no stream
1 parent 001b0f2 commit 9283f1f

3 files changed

Lines changed: 149 additions & 120 deletions

File tree

src/api/main.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from .routes import get_targeting_context
1818

1919
from azure.identity import DefaultAzureCredential
20-
from azure.appconfiguration.provider import load
2120

2221
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
2322

@@ -60,24 +59,29 @@ async def lifespan(app: fastapi.FastAPI):
6059
logger.error("Enable it via the 'Tracing' tab in your AI Foundry project page.")
6160
exit()
6261
else:
63-
from azure.monitor.opentelemetry import configure_azure_monitor
64-
from featuremanagement import FeatureManager
65-
from featuremanagement.azuremonitor import publish_telemetry, TargetingSpanProcessor
66-
67-
configure_azure_monitor(connection_string=application_insights_connection_string, span_processors=[TargetingSpanProcessor(targeting_context_accessor=get_targeting_context)])
68-
69-
# Initialize the feature manager / TODO: Add null check
70-
app_config_conn_str = os.getenv("APP_CONFIGURATION_ENDPOINT") # this will become: project.experiments.get_connection_string()
71-
72-
app_config = load(
73-
endpoint=app_config_conn_str,
74-
credential=DefaultAzureCredential(),
75-
feature_flag_enabled=True,
76-
feature_flag_refresh_enabled=True,
77-
refresh_interval=30, # 30 seconds
78-
)
79-
feature_manager = FeatureManager(app_config, targeting_context_accessor=get_targeting_context, on_feature_evaluated=publish_telemetry)
80-
app.state.feature_manager = feature_manager
62+
from azure.monitor.opentelemetry import configure_azure_monitor
63+
app_config_conn_str = os.getenv("APP_CONFIGURATION_ENDPOINT")
64+
if app_config_conn_str:
65+
from azure.appconfiguration.provider import load
66+
from featuremanagement import FeatureManager
67+
from featuremanagement.azuremonitor import publish_telemetry, TargetingSpanProcessor
68+
logger.info("Configured Application Insights with App Configuration feature flag support")
69+
configure_azure_monitor(
70+
connection_string=application_insights_connection_string,
71+
span_processors=[TargetingSpanProcessor(targeting_context_accessor=get_targeting_context)])
72+
app_config = load(
73+
endpoint=app_config_conn_str,
74+
credential=DefaultAzureCredential(),
75+
feature_flag_enabled=True,
76+
feature_flag_refresh_enabled=True,
77+
refresh_interval=30, # 30 seconds
78+
)
79+
feature_manager = FeatureManager(app_config, targeting_context_accessor=get_targeting_context, on_feature_evaluated=publish_telemetry)
80+
app.state.app_config = app_config
81+
app.state.feature_manager = feature_manager
82+
else:
83+
logger.info("Configured Application Insights.")
84+
configure_azure_monitor(connection_string=application_insights_connection_string)
8185

8286
chat = await project.inference.get_chat_completions_client()
8387
embed = await project.inference.get_embeddings_client()
@@ -106,8 +110,7 @@ async def lifespan(app: fastapi.FastAPI):
106110

107111
app.state.chat = chat
108112
app.state.search_index_manager = search_index_manager
109-
app.state.chat_model = os.environ["AZURE_AI_CHAT_DEPLOYMENT_NAME"]
110-
113+
app.state.chat_model = os.environ["AZURE_AI_CHAT_DEPLOYMENT_NAME"]
111114

112115
yield
113116

@@ -166,6 +169,8 @@ def create_app():
166169
else:
167170
logger.info("Tracing is not enabled")
168171

172+
# TODO: enable_app_config and make sure libraries are installed
173+
169174
app = fastapi.FastAPI(lifespan=lifespan)
170175

171176
static_dir = os.path.join(os.path.dirname(__file__), "static")

src/api/routes.py

Lines changed: 93 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
from fastapi.responses import HTMLResponse
1111
from fastapi.templating import Jinja2Templates
1212

13-
from featuremanagement.azuremonitor import track_event
14-
1513
import uuid
1614
import pathlib
1715
from azure.ai.inference.prompts import PromptTemplate
@@ -31,12 +29,14 @@
3129
)
3230

3331
from opentelemetry.baggage import get_baggage
34-
from azure.ai.evaluation import CoherenceEvaluator, FluencyEvaluator, RelevanceEvaluator, ViolenceEvaluator, SexualEvaluator, HateUnfairnessEvaluator, ProtectedMaterialEvaluator, ContentSafetyEvaluator
35-
import asyncio
3632
from opentelemetry.baggage import set_baggage, get_baggage
3733
from opentelemetry.context import attach
3834
from featuremanagement import TargetingContext, FeatureManager
39-
from azure.identity import DefaultAzureCredential
35+
from azure.appconfiguration.provider import AzureAppConfigurationProvider
36+
37+
# import asyncio
38+
# from azure.ai.evaluation import CoherenceEvaluator, FluencyEvaluator, RelevanceEvaluator, ViolenceEvaluator, SexualEvaluator, HateUnfairnessEvaluator, ProtectedMaterialEvaluator, ContentSafetyEvaluator
39+
# from azure.identity import DefaultAzureCredential
4040

4141
router = fastapi.APIRouter()
4242
templates = Jinja2Templates(directory="api/templates")
@@ -51,9 +51,12 @@ def get_chat_model(request: Request) -> str:
5151
def get_search_index_namager(request: Request) -> SearchIndexManager:
5252
return request.app.state.search_index_manager
5353

54-
def get_feature_manager(request: Request) -> str:
54+
def get_feature_manager(request: Request) -> FeatureManager:
5555
return request.app.state.feature_manager
5656

57+
def get_app_config(request: Request) -> AzureAppConfigurationProvider:
58+
return request.app.state.app_config
59+
5760
class Message(pydantic.BaseModel):
5861
content: str
5962
role: str = "user"
@@ -78,16 +81,20 @@ async def chat_stream_handler(
7881
chat_client: ChatCompletionsClient = Depends(get_chat_client),
7982
model_deployment_name: str = Depends(get_chat_model),
8083
search_index_manager: SearchIndexManager = Depends(get_search_index_namager),
81-
feature_manager: FeatureManager = Depends(get_feature_manager)
84+
feature_manager: FeatureManager = Depends(get_feature_manager),
85+
app_config: AzureAppConfigurationProvider = Depends(get_app_config),
8286
) -> fastapi.responses.StreamingResponse:
8387
if chat_client is None:
8488
raise Exception("Chat client not initialized")
8589

8690
async def response_stream():
8791
messages = [{"role": message.role, "content": message.content} for message in chat_request.messages]
8892

89-
targeting_id = chat_request.sessionState.get('sessionId', str(uuid.uuid4()))
90-
attach(set_baggage("Microsoft.TargetingId", targeting_id))
93+
# Refresh config and set targeting context for analysis
94+
if app_config and feature_manager:
95+
app_config.refresh()
96+
targeting_id = chat_request.sessionState.get('sessionId', str(uuid.uuid4()))
97+
attach(set_baggage("Microsoft.TargetingId", targeting_id))
9198

9299
# figure out which prompty template to use
93100
prompt_template = "prompt.v1.prompty"
@@ -172,54 +179,69 @@ async def response_stream():
172179
+ "\n"
173180
)
174181

182+
# TODO: add variant to response
183+
175184
return fastapi.responses.StreamingResponse(response_stream())
176185

177186

178187
def get_targeting_context() -> TargetingContext:
179188
return TargetingContext(user_id=get_baggage("Microsoft.TargetingId"))
180189

181-
# @router.post("/chat")
182-
# async def chat_nostream_handler(
183-
# chat_request: ChatRequest,
184-
# request: Request
185-
# ):
186-
# chat_client = globals["chat"]
187-
# if chat_client is None:
188-
# raise Exception("Chat client not initialized")
190+
@router.post("/chat")
191+
async def chat_nostream_handler(
192+
chat_request: ChatRequest,
193+
chat_client: ChatCompletionsClient = Depends(get_chat_client),
194+
model_deployment_name: str = Depends(get_chat_model),
195+
search_index_manager: SearchIndexManager = Depends(get_search_index_namager),
196+
feature_manager: FeatureManager = Depends(get_feature_manager),
197+
app_config: AzureAppConfigurationProvider = Depends(get_app_config),
198+
):
199+
messages = [{"role": message.role, "content": message.content} for message in chat_request.messages]
189200

190-
# messages = [{"role": message.role, "content": message.content} for message in chat_request.messages]
191-
# model_deployment_name = globals["chat_model"]
192-
# feature_manager = globals["feature_manager"]
193-
194-
# targeting_id = chat_request.sessionState.get('sessionId', str(uuid.uuid4()))
195-
# attach(set_baggage("Microsoft.TargetingId", targeting_id))
201+
# Refresh config and set targeting context for analysis
202+
if app_config and feature_manager:
203+
app_config.refresh()
204+
targeting_id = chat_request.sessionState.get('sessionId', str(uuid.uuid4()))
205+
attach(set_baggage("Microsoft.TargetingId", targeting_id))
196206

197-
# # figure out which prompty template to use (replace file to API)
198-
# variant = "none"
199-
# if chat_request.prompt_override:
200-
# prompt = PromptTemplate.from_prompty(pathlib.Path(__file__).parent.resolve() / chat_request.prompt_override)
201-
# variant = chat_request.prompt_override
202-
# else:
203-
# prompt_variant = feature_manager.get_variant("prompty_file") # replace this with prompt_asset
204-
# if prompt_variant and prompt_variant.configuration:
205-
# prompt = PromptTemplate.from_prompty(pathlib.Path(__file__).parent.resolve() / prompt_variant.configuration)
206-
# variant = prompt_variant.name
207-
# else:
208-
# prompt = globals["prompt"]
209-
210-
# prompt_messages = prompt.create_messages()
211-
212-
# try:
213-
# response = await chat_client.complete(
214-
# model=model_deployment_name, messages=prompt_messages + messages, stream=False
215-
# )
216-
# track_event("RequestMade", targeting_id)
217-
# answer = response.choices[0].message.content
218-
# except Exception as e:
219-
# error = {"Error": str(e)}
220-
# track_event("ErrorLLM", targeting_id, error)
221-
# return { "answer": str(e), "variant": variant }
207+
# figure out which prompty template to use
208+
prompt_template = "prompt.v1.prompty"
209+
if chat_request.prompt_override:
210+
prompt_template = chat_request.prompt_override
211+
elif feature_manager is not None:
212+
prompt_variant = feature_manager.get_variant("prompty_file") # replace this with prompt_asset
213+
if prompt_variant and prompt_variant.configuration: # TODO: check file exists
214+
prompt_template = prompt_variant.configuration
215+
216+
prompt = PromptTemplate.from_prompty(pathlib.Path(__file__).parent.resolve() / prompt_template)
217+
prompt_messages = prompt.create_messages()
218+
219+
# Use RAG model, only if we were provided index and we have found a context there.
220+
if search_index_manager is not None:
221+
context = await search_index_manager.search(chat_request)
222+
if context:
223+
prompt_messages = PromptTemplate.from_string(
224+
'You are a helpful assistant that answers some questions '
225+
'with the help of some context data.\n\nHere is '
226+
'the context data:\n\n{{context}}').create_messages(data=dict(context=context))
227+
logger.info(f"{prompt_messages=}")
228+
else:
229+
logger.info("Unable to find the relevant information in the index for the request.")
230+
231+
try:
232+
response = await chat_client.complete(
233+
model=model_deployment_name, messages=prompt_messages + messages, stream=False
234+
)
235+
answer = response.choices[0].message.content
236+
except Exception as e:
237+
error = {"Error": str(e)}
238+
#track_event("ErrorLLM", targeting_id, error)
239+
answer = error
240+
241+
return { "answer": answer, "variant": prompt_variant.name if prompt_variant else None }
242+
222243

244+
# Inline Evaluation Prototype
223245

224246
# conversation = {}
225247

@@ -247,28 +269,28 @@ def get_targeting_context() -> TargetingContext:
247269

248270
# asyncio.create_task(run_evals(eval_input, targeting_id, project.scope, DefaultAzureCredential()))
249271

250-
return { "answer": answer, "variant": variant }
272+
# return { "answer": answer, "variant": variant }
251273

252274

253-
async def run_evals(eval_input, targeting_id, ai_project_scope, credential):
254-
run_eval(FluencyEvaluator, eval_input, targeting_id)
255-
run_eval(RelevanceEvaluator, eval_input, targeting_id)
256-
run_eval(CoherenceEvaluator, eval_input, targeting_id)
257-
258-
run_safety_eval(ViolenceEvaluator, eval_input, targeting_id, ai_project_scope, credential)
259-
run_safety_eval(SexualEvaluator, eval_input, targeting_id, ai_project_scope, credential)
260-
run_safety_eval(HateUnfairnessEvaluator, eval_input, targeting_id, ai_project_scope, credential)
261-
run_safety_eval(ProtectedMaterialEvaluator, eval_input, targeting_id, ai_project_scope, credential)
262-
run_safety_eval(ContentSafetyEvaluator, eval_input, targeting_id, ai_project_scope, credential)
263-
264-
def run_safety_eval(evaluator, eval_input, targeting_id, ai_project_scope, credential):
265-
eval = evaluator(credential=credential, azure_ai_project=ai_project_scope)
266-
score = eval(**eval_input)
267-
score.update({"evaluator_id": eval.id})
268-
track_event("gen.ai." + type(eval).__name__, targeting_id, score)
269-
270-
def run_eval(evaluator, eval_input, targeting_id):
271-
eval = evaluator(globals["model_config"])
272-
score = eval(**eval_input)
273-
score.update({"evaluator_id": evaluator.id})
274-
track_event("gen.ai." + evaluator.__name__, targeting_id, score)
275+
# async def run_evals(eval_input, targeting_id, ai_project_scope, credential):
276+
# run_eval(FluencyEvaluator, eval_input, targeting_id)
277+
# run_eval(RelevanceEvaluator, eval_input, targeting_id)
278+
# run_eval(CoherenceEvaluator, eval_input, targeting_id)
279+
280+
# run_safety_eval(ViolenceEvaluator, eval_input, targeting_id, ai_project_scope, credential)
281+
# run_safety_eval(SexualEvaluator, eval_input, targeting_id, ai_project_scope, credential)
282+
# run_safety_eval(HateUnfairnessEvaluator, eval_input, targeting_id, ai_project_scope, credential)
283+
# run_safety_eval(ProtectedMaterialEvaluator, eval_input, targeting_id, ai_project_scope, credential)
284+
# run_safety_eval(ContentSafetyEvaluator, eval_input, targeting_id, ai_project_scope, credential)
285+
286+
# def run_safety_eval(evaluator, eval_input, targeting_id, ai_project_scope, credential):
287+
# eval = evaluator(credential=credential, azure_ai_project=ai_project_scope)
288+
# score = eval(**eval_input)
289+
# score.update({"evaluator_id": eval.id})
290+
# track_event("gen.ai." + type(eval).__name__, targeting_id, score)
291+
292+
# def run_eval(evaluator, eval_input, targeting_id):
293+
# eval = evaluator(globals["model_config"])
294+
# score = eval(**eval_input)
295+
# score.update({"evaluator_id": evaluator.id})
296+
# track_event("gen.ai." + evaluator.__name__, targeting_id, score)

src/api/templates/index.html

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -146,43 +146,45 @@
146146

147147
try {
148148

149-
const result = await client.getStreamedCompletion(messages);
150-
151-
let answer = "";
152-
for await (const response of result) {
153-
if (!response.delta) {
154-
continue;
155-
}
156-
if (response.delta.content) {
157-
// Clear out the DIV if its the first answer chunk we've received
158-
if (answer == "") {
159-
messageDiv.innerHTML = "";
160-
}
161-
answer += response.delta.content;
162-
messageDiv.innerHTML = converter.makeHtml(answer);
163-
messageDiv.scrollIntoView();
164-
}
165-
if (response.error) {
166-
messageDiv.innerHTML = "Error: " + response.error;
167-
}
168-
}
169-
149+
// Uncomment the following lines if you want to use the streaming version
170150

151+
// const result = await client.getStreamedCompletion(messages, { "sessionState": { "sessionId": sessionId }});
152+
153+
// let answer = "";
154+
// for await (const response of result) {
155+
// if (!response.delta) {
156+
// continue;
157+
// }
158+
// if (response.delta.content) {
159+
// // Clear out the DIV if it's the first answer chunk we've received
160+
// if (answer == "") {
161+
// messageDiv.innerHTML = "";
162+
// }
163+
// answer += response.delta.content;
164+
// messageDiv.innerHTML = converter.makeHtml(answer);
165+
// messageDiv.scrollIntoView();
166+
// }
167+
// if (response.error) {
168+
// messageDiv.innerHTML = "Error: " + response.error;
169+
// }
170+
// }
171171

172-
// const response = await client.getCompletion(messages, { "sessionState": { "sessionId": sessionId }});
173-
// const answer = response.answer;
172+
173+
// Uncomment the following lines if you want to use the non-streaming version
174+
const response = await client.getCompletion(messages, { "sessionState": { "sessionId": sessionId }});
175+
const answer = response.answer;
174176

175-
// messageDiv.innerHTML = converter.makeHtml(answer);
176-
// messageDiv.scrollIntoView();
177+
messageDiv.innerHTML = converter.makeHtml(answer);
178+
messageDiv.scrollIntoView();
177179

178180
messages.push({
179181
"role": "assistant",
180182
"content": answer
181183
});
182184

183-
// if (response.variant) {
184-
// messageTitleDiv.innerHTML += ` (Prompt Variant: ${response.variant})`;
185-
// }
185+
if (response.variant) {
186+
messageTitleDiv.innerHTML += ` (Prompt Variant: ${response.variant})`;
187+
}
186188

187189
messageInput.value = "";
188190
} catch (error) {

0 commit comments

Comments
 (0)