Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 70 additions & 51 deletions crates/core/src/api/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ fn emit_llm_start(
request_codec: Option<&dyn LlmCodec>,
) -> Result<()> {
ensure_runtime_owner()?;
let (event, subscribers) = {
let (entries, subscribers) = {
let scope_stack = current_scope_stack();
let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
Expand All @@ -291,20 +291,27 @@ fn emit_llm_start(
let state = context
.read()
.map_err(|error| FlowError::Internal(error.to_string()))?;

let sanitized_request = state.llm_sanitize_request_chain(request.clone(), &scope_locals);
let annotated_request = match request_codec {
Some(codec)
if sanitized_request.headers != request.headers
|| sanitized_request.content != request.content =>
{
codec.decode(&sanitized_request).ok().map(Arc::new)
}
_ => annotated_request,
};
let input = serde_json::to_value(&sanitized_request).unwrap_or(Json::Null);
let event = state.build_llm_start_event(handle, Some(input), annotated_request);
(event, subscribers)
let entries = state.llm_sanitize_request_entries(&scope_locals);
(entries, subscribers)
};
let sanitized_request =
NemoRelayContextState::llm_sanitize_request_snapshot_chain(request.clone(), &entries);
let annotated_request = match request_codec {
Some(codec)
if sanitized_request.headers != request.headers
|| sanitized_request.content != request.content =>
{
codec.decode(&sanitized_request).ok().map(Arc::new)
}
_ => annotated_request,
};
let input = serde_json::to_value(&sanitized_request).unwrap_or(Json::Null);
let event = {
let context = global_context();
let state = context
.read()
.map_err(|error| FlowError::Internal(error.to_string()))?;
state.build_llm_start_event(handle, Some(input), annotated_request)
};
NemoRelayContextState::emit_event(&event, &subscribers);
Ok(())
Expand Down Expand Up @@ -416,8 +423,7 @@ fn llm_call_end_with_behavior(
timestamp,
} = params;
ensure_runtime_owner()?;
let mut decode_error = None;
let (event, subscribers) = {
let (entries, subscribers) = {
let scope_stack = current_scope_stack();
let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
Expand All @@ -429,32 +435,40 @@ fn llm_call_end_with_behavior(
let state = context
.read()
.map_err(|error| FlowError::Internal(error.to_string()))?;

let sanitized_response = state.llm_sanitize_response_chain(response, &scope_locals);
let data = if sanitized_response.is_null() {
data
} else {
Some(sanitized_response)
};
let annotated_response = match annotated_response {
Some(annotated_response) => Some(annotated_response),
None => match (response_codec.as_ref(), data.as_ref()) {
(Some(codec), Some(response)) => match codec.decode_response(response) {
Ok(mut decoded) => {
if behavior.attach_estimated_cost {
attach_estimated_cost_for_provider(&mut decoded, Some(&handle.name));
}
Some(Arc::new(decoded))
}
Err(error) => {
decode_error = Some(error);
None
let entries = state.llm_sanitize_response_entries(&scope_locals);
(entries, subscribers)
};
let sanitized_response =
NemoRelayContextState::llm_sanitize_response_snapshot_chain(response, &entries);
let data = if sanitized_response.is_null() {
data
} else {
Some(sanitized_response)
};
let mut decode_error = None;
let annotated_response = match annotated_response {
Some(annotated_response) => Some(annotated_response),
None => match (response_codec.as_ref(), data.as_ref()) {
(Some(codec), Some(response)) => match codec.decode_response(response) {
Ok(mut decoded) => {
if behavior.attach_estimated_cost {
attach_estimated_cost_for_provider(&mut decoded, Some(&handle.name));
}
},
_ => None,
Some(Arc::new(decoded))
}
Err(error) => {
decode_error = Some(error);
None
}
},
};

_ => None,
},
};
let event = {
let context = global_context();
let state = context
.read()
.map_err(|error| FlowError::Internal(error.to_string()))?;
let end_metadata = metadata_with_otel_status(metadata, "OK", None);
let event = state.build_llm_end_event(
EndLlmHandleParams::builder()
Expand All @@ -465,7 +479,7 @@ fn llm_call_end_with_behavior(
.timestamp_opt(timestamp)
.build(),
);
(event, subscribers)
event
};
NemoRelayContextState::emit_event(&event, &subscribers);
if let Some(error) = decode_error
Expand Down Expand Up @@ -825,15 +839,20 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu
/// this helper.
pub fn llm_request_intercepts(name: &str, request: LlmRequest) -> Result<LlmRequest> {
    ensure_runtime_owner()?;
    // Snapshot the intercept entries inside a block so that both read guards
    // (scope stack and global context) are dropped before the chain executes.
    // Holding them across the chain, or acquiring them a second time while a
    // first guard is still alive, risks a deadlock or panic on `RwLock::read`.
    let entries = {
        let scope_stack = current_scope_stack();
        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
        let scope_locals = scope_guard
            .collect_scope_local_registries(|registries| &registries.llm_request_intercepts);
        let context = global_context();
        // A poisoned global-context lock is reported as an internal error
        // rather than a panic.
        let state = context
            .read()
            .map_err(|error| FlowError::Internal(error.to_string()))?;
        state.llm_request_intercept_entries(&scope_locals)
    };
    // Run the intercept chain exactly once, against the snapshot. Only the
    // rewritten request is returned; the second tuple element is discarded.
    let (request, _) = NemoRelayContextState::llm_request_intercepts_snapshot_chain(
        name, request, None, &entries,
    )?;
    Ok(request)
}

Expand Down
Loading