diff --git a/crates/core/src/api/llm.rs b/crates/core/src/api/llm.rs index 7f2f5d51..6c6bc97a 100644 --- a/crates/core/src/api/llm.rs +++ b/crates/core/src/api/llm.rs @@ -279,7 +279,7 @@ fn emit_llm_start( request_codec: Option<&dyn LlmCodec>, ) -> Result<()> { ensure_runtime_owner()?; - let (event, subscribers) = { + let (entries, subscribers) = { let scope_stack = current_scope_stack(); let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); let scope_locals = scope_guard.collect_scope_local_registries(|registries| { @@ -291,20 +291,27 @@ fn emit_llm_start( let state = context .read() .map_err(|error| FlowError::Internal(error.to_string()))?; - - let sanitized_request = state.llm_sanitize_request_chain(request.clone(), &scope_locals); - let annotated_request = match request_codec { - Some(codec) - if sanitized_request.headers != request.headers - || sanitized_request.content != request.content => - { - codec.decode(&sanitized_request).ok().map(Arc::new) - } - _ => annotated_request, - }; - let input = serde_json::to_value(&sanitized_request).unwrap_or(Json::Null); - let event = state.build_llm_start_event(handle, Some(input), annotated_request); - (event, subscribers) + let entries = state.llm_sanitize_request_entries(&scope_locals); + (entries, subscribers) + }; + let sanitized_request = + NemoRelayContextState::llm_sanitize_request_snapshot_chain(request.clone(), &entries); + let annotated_request = match request_codec { + Some(codec) + if sanitized_request.headers != request.headers + || sanitized_request.content != request.content => + { + codec.decode(&sanitized_request).ok().map(Arc::new) + } + _ => annotated_request, + }; + let input = serde_json::to_value(&sanitized_request).unwrap_or(Json::Null); + let event = { + let context = global_context(); + let state = context + .read() + .map_err(|error| FlowError::Internal(error.to_string()))?; + state.build_llm_start_event(handle, Some(input), annotated_request) }; NemoRelayContextState::emit_event(&event, &subscribers); Ok(()) @@ -416,8 +423,7 @@ fn llm_call_end_with_behavior( timestamp, } = params; ensure_runtime_owner()?; - let mut decode_error = None; - let (event, subscribers) = { + let (entries, subscribers) = { let scope_stack = current_scope_stack(); let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); let scope_locals = scope_guard.collect_scope_local_registries(|registries| { @@ -429,32 +435,40 @@ fn llm_call_end_with_behavior( let state = context .read() .map_err(|error| FlowError::Internal(error.to_string()))?; - - let sanitized_response = state.llm_sanitize_response_chain(response, &scope_locals); - let data = if sanitized_response.is_null() { - data - } else { - Some(sanitized_response) - }; - let annotated_response = match annotated_response { - Some(annotated_response) => Some(annotated_response), - None => match (response_codec.as_ref(), data.as_ref()) { - (Some(codec), Some(response)) => match codec.decode_response(response) { - Ok(mut decoded) => { - if behavior.attach_estimated_cost { - attach_estimated_cost_for_provider(&mut decoded, Some(&handle.name)); - } - Some(Arc::new(decoded)) - } - Err(error) => { - decode_error = Some(error); - None + let entries = state.llm_sanitize_response_entries(&scope_locals); + (entries, subscribers) + }; + let sanitized_response = + NemoRelayContextState::llm_sanitize_response_snapshot_chain(response, &entries); + let data = if sanitized_response.is_null() { + data + } else { + Some(sanitized_response) + }; + let mut decode_error = None; + let annotated_response = match annotated_response { + Some(annotated_response) => Some(annotated_response), + None => match (response_codec.as_ref(), data.as_ref()) { + (Some(codec), Some(response)) => match codec.decode_response(response) { + Ok(mut decoded) => { + if behavior.attach_estimated_cost { + attach_estimated_cost_for_provider(&mut decoded, Some(&handle.name)); } - }, - _ => None, + Some(Arc::new(decoded)) + } + Err(error) => { + decode_error = Some(error); + None + } }, - }; - + _ => None, + }, + }; + let event = { + let context = global_context(); + let state = context + .read() + .map_err(|error| FlowError::Internal(error.to_string()))?; let end_metadata = metadata_with_otel_status(metadata, "OK", None); let event = state.build_llm_end_event( EndLlmHandleParams::builder() @@ -465,7 +479,7 @@ fn llm_call_end_with_behavior( .timestamp_opt(timestamp) .build(), ); - (event, subscribers) + event }; NemoRelayContextState::emit_event(&event, &subscribers); if let Some(error) = decode_error @@ -825,15 +839,20 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu /// this helper. pub fn llm_request_intercepts(name: &str, request: LlmRequest) -> Result { ensure_runtime_owner()?; - let scope_stack = current_scope_stack(); - let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); - let scope_locals = - scope_guard.collect_scope_local_registries(|registries| ®istries.llm_request_intercepts); - let context = global_context(); - let state = context - .read() - .map_err(|error| FlowError::Internal(error.to_string()))?; - let (request, _) = state.llm_request_intercepts_chain(name, request, None, &scope_locals)?; + let entries = { + let scope_stack = current_scope_stack(); + let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); + let scope_locals = scope_guard + .collect_scope_local_registries(|registries| ®istries.llm_request_intercepts); + let context = global_context(); + let state = context + .read() + .map_err(|error| FlowError::Internal(error.to_string()))?; + state.llm_request_intercept_entries(&scope_locals) + }; + let (request, _) = NemoRelayContextState::llm_request_intercepts_snapshot_chain( + name, request, None, &entries, + )?; Ok(request) } diff --git a/crates/core/src/api/runtime/state.rs b/crates/core/src/api/runtime/state.rs index ee69b6f6..70276786 100644 --- a/crates/core/src/api/runtime/state.rs +++ b/crates/core/src/api/runtime/state.rs @@ -592,23 +592,39 @@ impl NemoRelayContextState { Self::emit_event(&event, subscribers); } - /// Run tool request sanitizers across global and scope-local registries. + /// Snapshot tool request sanitizers in priority order. /// /// # Parameters - /// - `name`: Tool name associated with the request. - /// - `args`: Raw tool arguments to sanitize for observability. /// - `scope_locals`: Scope-local sanitizer registries collected from the /// active scope stack. /// /// # Returns - /// The sanitized JSON payload after every matching guardrail has run. - pub(crate) fn tool_sanitize_request_chain( + /// Named sanitizer snapshots that can be evaluated after registry locks + /// are released. + pub(crate) fn tool_sanitize_request_entries( &self, + scope_locals: &[&SortedRegistry>], + ) -> Vec> { + merge_guardrail_entries(&self.tool_sanitize_request_guardrails, scope_locals) + .into_iter() + .cloned() + .collect() + } + + /// Run a snapshot of tool request sanitizers in priority order. + /// + /// # Parameters + /// - `name`: Tool name associated with the request. + /// - `args`: Raw tool arguments to sanitize for observability. + /// - `entries`: Sanitizer snapshots to evaluate. + /// + /// # Returns + /// The sanitized JSON payload after every provided guardrail has run. + pub(crate) fn tool_sanitize_request_snapshot_chain( name: &str, args: Json, - scope_locals: &[&SortedRegistry>], + entries: &[Guardrail], ) -> Json { - let entries = merge_guardrail_entries(&self.tool_sanitize_request_guardrails, scope_locals); let mut value = args; for entry in entries { value = (entry.payload)(name, value); @@ -616,24 +632,39 @@ impl NemoRelayContextState { value } - /// Run tool response sanitizers across global and scope-local registries. + /// Snapshot tool response sanitizers in priority order. /// /// # Parameters - /// - `name`: Tool name associated with the response. - /// - `result`: Raw tool result to sanitize for observability. /// - `scope_locals`: Scope-local sanitizer registries collected from the /// active scope stack. /// /// # Returns - /// The sanitized JSON payload after every matching guardrail has run. - pub(crate) fn tool_sanitize_response_chain( + /// Named sanitizer snapshots that can be evaluated after registry locks + /// are released. + pub(crate) fn tool_sanitize_response_entries( &self, + scope_locals: &[&SortedRegistry>], + ) -> Vec> { + merge_guardrail_entries(&self.tool_sanitize_response_guardrails, scope_locals) + .into_iter() + .cloned() + .collect() + } + + /// Run a snapshot of tool response sanitizers in priority order. + /// + /// # Parameters + /// - `name`: Tool name associated with the response. + /// - `result`: Raw tool result to sanitize for observability. + /// - `entries`: Sanitizer snapshots to evaluate. + /// + /// # Returns + /// The sanitized JSON payload after every provided guardrail has run. + pub(crate) fn tool_sanitize_response_snapshot_chain( name: &str, result: Json, - scope_locals: &[&SortedRegistry>], + entries: &[Guardrail], ) -> Json { - let entries = - merge_guardrail_entries(&self.tool_sanitize_response_guardrails, scope_locals); let mut value = result; for entry in entries { value = (entry.payload)(name, value); @@ -729,15 +760,33 @@ impl NemoRelayContextState { Ok(None) } - /// Run tool request intercepts in priority order. + /// Snapshot tool request intercepts in priority order. /// /// # Parameters - /// - `name`: Tool name associated with the request. - /// - `args`: Tool arguments to pass through the intercept chain. /// - `scope_locals`: Scope-local request intercept registries collected /// from the active scope stack. /// /// # Returns + /// Named intercept snapshots that can be evaluated after registry locks + /// are released. + pub(crate) fn tool_request_intercept_entries( + &self, + scope_locals: &[&SortedRegistry>], + ) -> Vec> { + merge_intercept_entries(&self.tool_request_intercepts, scope_locals) + .into_iter() + .cloned() + .collect() + } + + /// Run a snapshot of tool request intercepts in priority order. + /// + /// # Parameters + /// - `name`: Tool name associated with the request. + /// - `args`: Tool arguments to pass through the intercept chain. + /// - `entries`: Intercept snapshots to evaluate. + /// + /// # Returns /// A [`Result`] containing the final JSON argument payload. /// /// # Errors @@ -746,13 +795,11 @@ impl NemoRelayContextState { /// # Notes /// If an intercept entry has `break_chain` enabled, later intercepts are /// skipped after that entry runs. - pub(crate) fn tool_request_intercepts_chain( - &self, + pub(crate) fn tool_request_intercepts_snapshot_chain( name: &str, args: Json, - scope_locals: &[&SortedRegistry>], + entries: &[Intercept], ) -> crate::error::Result { - let entries = merge_intercept_entries(&self.tool_request_intercepts, scope_locals); let mut value = args; for entry in entries { value = (entry.payload.callable)(name, value)?; @@ -792,21 +839,37 @@ impl NemoRelayContextState { next } - /// Run LLM request sanitizers across global and scope-local registries. + /// Snapshot LLM request sanitizers in priority order. /// /// # Parameters - /// - `request`: Raw LLM request to sanitize for observability. /// - `scope_locals`: Scope-local sanitizer registries collected from the /// active scope stack. /// /// # Returns - /// The sanitized [`LlmRequest`] after every matching guardrail has run. - pub(crate) fn llm_sanitize_request_chain( + /// Named sanitizer snapshots that can be evaluated after registry locks + /// are released. + pub(crate) fn llm_sanitize_request_entries( &self, - request: LlmRequest, scope_locals: &[&SortedRegistry>], + ) -> Vec> { + merge_guardrail_entries(&self.llm_sanitize_request_guardrails, scope_locals) + .into_iter() + .cloned() + .collect() + } + + /// Run a snapshot of LLM request sanitizers in priority order. + /// + /// # Parameters + /// - `request`: Raw LLM request to sanitize for observability. + /// - `entries`: Sanitizer snapshots to evaluate. + /// + /// # Returns + /// The sanitized [`LlmRequest`] after every provided guardrail has run. + pub(crate) fn llm_sanitize_request_snapshot_chain( + request: LlmRequest, + entries: &[Guardrail], ) -> LlmRequest { - let entries = merge_guardrail_entries(&self.llm_sanitize_request_guardrails, scope_locals); let mut value = request; for entry in entries { value = (entry.payload)(value); @@ -814,21 +877,37 @@ impl NemoRelayContextState { value } - /// Run LLM response sanitizers across global and scope-local registries. + /// Snapshot LLM response sanitizers in priority order. /// /// # Parameters - /// - `response`: Raw response payload to sanitize for observability. /// - `scope_locals`: Scope-local sanitizer registries collected from the /// active scope stack. /// /// # Returns - /// The sanitized response payload after every matching guardrail has run. - pub(crate) fn llm_sanitize_response_chain( + /// Named sanitizer snapshots that can be evaluated after registry locks + /// are released. + pub(crate) fn llm_sanitize_response_entries( &self, - response: Json, scope_locals: &[&SortedRegistry>], + ) -> Vec> { + merge_guardrail_entries(&self.llm_sanitize_response_guardrails, scope_locals) + .into_iter() + .cloned() + .collect() + } + + /// Run a snapshot of LLM response sanitizers in priority order. + /// + /// # Parameters + /// - `response`: Raw response payload to sanitize for observability. + /// - `entries`: Sanitizer snapshots to evaluate. + /// + /// # Returns + /// The sanitized response payload after every provided guardrail has run. + pub(crate) fn llm_sanitize_response_snapshot_chain( + response: Json, + entries: &[Guardrail], ) -> Json { - let entries = merge_guardrail_entries(&self.llm_sanitize_response_guardrails, scope_locals); let mut value = response; for entry in entries { value = (entry.payload)(value); @@ -921,15 +1000,33 @@ impl NemoRelayContextState { Ok(None) } - /// Run LLM request intercepts in priority order. + /// Snapshot LLM request intercepts in priority order. + /// + /// # Parameters + /// - `scope_locals`: Scope-local request intercept registries collected + /// from the active scope stack. + /// + /// # Returns + /// Named intercept snapshots that can be evaluated after registry locks + /// are released. + pub(crate) fn llm_request_intercept_entries( + &self, + scope_locals: &[&SortedRegistry>], + ) -> Vec> { + merge_intercept_entries(&self.llm_request_intercepts, scope_locals) + .into_iter() + .cloned() + .collect() + } + + /// Run a snapshot of LLM request intercepts in priority order. /// /// # Parameters /// - `name`: Logical provider or model family name. /// - `request`: LLM request to pass through the intercept chain. /// - `annotated`: Optional normalized request annotation to carry through /// the chain. - /// - `scope_locals`: Scope-local request intercept registries collected - /// from the active scope stack. + /// - `entries`: Intercept snapshots to evaluate. /// /// # Returns /// A [`Result`] containing the final request and annotation pair. @@ -940,14 +1037,12 @@ impl NemoRelayContextState { /// # Notes /// If an intercept entry has `break_chain` enabled, later intercepts are /// skipped after that entry runs. - pub(crate) fn llm_request_intercepts_chain( - &self, + pub(crate) fn llm_request_intercepts_snapshot_chain( name: &str, request: LlmRequest, annotated: Option, - scope_locals: &[&SortedRegistry>], + entries: &[Intercept], ) -> crate::error::Result<(LlmRequest, Option)> { - let entries = merge_intercept_entries(&self.llm_request_intercepts, scope_locals); let mut request_value = request; let mut annotated_value = annotated; for entry in entries { diff --git a/crates/core/src/api/shared.rs b/crates/core/src/api/shared.rs index 52f0da5b..861cd41c 100644 --- a/crates/core/src/api/shared.rs +++ b/crates/core/src/api/shared.rs @@ -75,24 +75,29 @@ pub(crate) fn run_request_intercepts_with_codec( request: LlmRequest, codec: Option>, ) -> Result<(LlmRequest, Option>)> { - let scope_stack = current_scope_stack(); - let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); - let scope_locals = - scope_guard.collect_scope_local_registries(|registries| ®istries.llm_request_intercepts); - - let context = global_context(); - let state = context - .read() - .map_err(|error| FlowError::Internal(error.to_string()))?; - let original = request.clone(); let annotated = match &codec { Some(codec) => Some(codec.decode(&request)?), None => None, }; + let entries = { + let scope_stack = current_scope_stack(); + let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); + let scope_locals = scope_guard + .collect_scope_local_registries(|registries| ®istries.llm_request_intercepts); + + let context = global_context(); + let state = context + .read() + .map_err(|error| FlowError::Internal(error.to_string()))?; + state.llm_request_intercept_entries(&scope_locals) + }; + let (intercepted_request, intercepted_annotated) = - state.llm_request_intercepts_chain(name, request, annotated, &scope_locals)?; + crate::api::runtime::NemoRelayContextState::llm_request_intercepts_snapshot_chain( + name, request, annotated, &entries, + )?; match (codec, intercepted_annotated) { (Some(codec), Some(annotated)) => { diff --git a/crates/core/src/api/tool.rs b/crates/core/src/api/tool.rs index c01e4be9..bf9dad77 100644 --- a/crates/core/src/api/tool.rs +++ b/crates/core/src/api/tool.rs @@ -214,7 +214,7 @@ pub struct ToolCallEndParams<'a> { pub fn tool_call(params: ToolCallParams<'_>) -> Result { ensure_runtime_owner()?; let parent_uuid = resolve_parent_uuid(params.parent); - let (handle, event, subscribers) = { + let (entries, subscribers) = { let scope_stack = current_scope_stack(); let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); let scope_locals = scope_guard.collect_scope_local_registries(|registries| { @@ -226,9 +226,19 @@ pub fn tool_call(params: ToolCallParams<'_>) -> Result { let state = context .read() .map_err(|error| FlowError::Internal(error.to_string()))?; - - let sanitized_args = - state.tool_sanitize_request_chain(params.name, params.args, &scope_locals); + let entries = state.tool_sanitize_request_entries(&scope_locals); + (entries, subscribers) + }; + let sanitized_args = NemoRelayContextState::tool_sanitize_request_snapshot_chain( + params.name, + params.args, + &entries, + ); + let (handle, event) = { + let context = global_context(); + let state = context + .read() + .map_err(|error| FlowError::Internal(error.to_string()))?; let handle_params = CreateToolHandleParams::builder() .name(params.name) .parent_uuid_opt(parent_uuid) @@ -240,7 +250,7 @@ pub fn tool_call(params: ToolCallParams<'_>) -> Result { .build(); let handle = state.create_tool_handle(handle_params); let event = state.build_tool_start_event(&handle, Some(sanitized_args)); - (handle, event, subscribers) + (handle, event) }; NemoRelayContextState::emit_event(&event, &subscribers); Ok(handle) @@ -274,7 +284,7 @@ pub fn tool_call(params: ToolCallParams<'_>) -> Result { /// the caller-owned `result` value. pub fn tool_call_end(params: ToolCallEndParams<'_>) -> Result<()> { ensure_runtime_owner()?; - let (event, subscribers) = { + let (entries, subscribers) = { let scope_stack = current_scope_stack(); let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); let scope_locals = scope_guard.collect_scope_local_registries(|registries| { @@ -286,14 +296,24 @@ pub fn tool_call_end(params: ToolCallEndParams<'_>) -> Result<()> { let state = context .read() .map_err(|error| FlowError::Internal(error.to_string()))?; - - let sanitized_result = - state.tool_sanitize_response_chain(¶ms.handle.name, params.result, &scope_locals); - let data = if sanitized_result.is_null() { - params.data - } else { - Some(sanitized_result) - }; + let entries = state.tool_sanitize_response_entries(&scope_locals); + (entries, subscribers) + }; + let sanitized_result = NemoRelayContextState::tool_sanitize_response_snapshot_chain( + ¶ms.handle.name, + params.result, + &entries, + ); + let data = if sanitized_result.is_null() { + params.data + } else { + Some(sanitized_result) + }; + let event = { + let context = global_context(); + let state = context + .read() + .map_err(|error| FlowError::Internal(error.to_string()))?; let event = state.build_tool_end_event( EndToolHandleParams::builder() .handle(params.handle) @@ -302,7 +322,7 @@ pub fn tool_call_end(params: ToolCallEndParams<'_>) -> Result<()> { .timestamp_opt(params.timestamp) .build(), ); - (event, subscribers) + event }; NemoRelayContextState::emit_event(&event, &subscribers); Ok(()) @@ -411,7 +431,7 @@ pub async fn tool_call_execute(params: ToolCallExecuteParams) -> Result { } } - let intercepted_args = { + let intercept_entries = { let scope_stack = current_scope_stack(); let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); let scope_locals = scope_guard @@ -420,8 +440,13 @@ pub async fn tool_call_execute(params: ToolCallExecuteParams) -> Result { let state = context .read() .map_err(|error| FlowError::Internal(error.to_string()))?; - state.tool_request_intercepts_chain(&name, args, &scope_locals)? + state.tool_request_intercept_entries(&scope_locals) }; + let intercepted_args = NemoRelayContextState::tool_request_intercepts_snapshot_chain( + &name, + args, + &intercept_entries, + )?; let handle = tool_call( ToolCallParams::builder() @@ -487,15 +512,18 @@ pub async fn tool_call_execute(params: ToolCallExecuteParams) -> Result { /// Conditional guardrails and execution intercepts are not run by this helper. pub fn tool_request_intercepts(name: &str, args: Json) -> Result { ensure_runtime_owner()?; - let scope_stack = current_scope_stack(); - let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); - let scope_locals = scope_guard - .collect_scope_local_registries(|registries| ®istries.tool_request_intercepts); - let context = global_context(); - let state = context - .read() - .map_err(|error| FlowError::Internal(error.to_string()))?; - state.tool_request_intercepts_chain(name, args, &scope_locals) + let entries = { + let scope_stack = current_scope_stack(); + let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); + let scope_locals = scope_guard + .collect_scope_local_registries(|registries| ®istries.tool_request_intercepts); + let context = global_context(); + let state = context + .read() + .map_err(|error| FlowError::Internal(error.to_string()))?; + state.tool_request_intercept_entries(&scope_locals) + }; + NemoRelayContextState::tool_request_intercepts_snapshot_chain(name, args, &entries) } /// Run only the tool conditional-execution guardrail chain. diff --git a/crates/core/src/stream.rs b/crates/core/src/stream.rs index 65b46f47..c2dde21f 100644 --- a/crates/core/src/stream.rs +++ b/crates/core/src/stream.rs @@ -151,7 +151,7 @@ impl LlmStreamWrapper { None => Json::Null, }; - let event_snapshot = { + let snapshot = { let ss_guard = self.scope_stack.read().expect("scope stack lock poisoned"); let sl = ss_guard.collect_scope_local_registries(|r| &r.llm_sanitize_response_guardrails); @@ -161,32 +161,42 @@ impl LlmStreamWrapper { match state { Ok(state) => { let subscribers = state.collect_event_subscribers(&sl_subs); - let sanitized = state.llm_sanitize_response_chain(aggregated, &sl); - let data = if sanitized.is_null() { - self.handle.data.clone() - } else { - Some(sanitized) - }; - let annotated_response: Option> = self - .response_codec - .as_ref() - .and_then(|codec| { - let mut decoded = codec.decode_response(data.as_ref()?).ok()?; - attach_estimated_cost_for_provider( - &mut decoded, - Some(&self.handle.name), - ); - Some(decoded) - }) - .map(Arc::new); - let event = - state.end_llm_handle(&self.handle, data, metadata, annotated_response); - Some((event, subscribers)) + let entries = state.llm_sanitize_response_entries(&sl); + Some((entries, subscribers)) } Err(_) => None, } }; - if let Some((event, subscribers)) = event_snapshot { + let Some((entries, subscribers)) = snapshot else { + return; + }; + let sanitized = + NemoRelayContextState::llm_sanitize_response_snapshot_chain(aggregated, &entries); + let data = if sanitized.is_null() { + self.handle.data.clone() + } else { + Some(sanitized) + }; + let annotated_response: Option> = self + .response_codec + .as_ref() + .and_then(|codec| { + let mut decoded = codec.decode_response(data.as_ref()?).ok()?; + attach_estimated_cost_for_provider(&mut decoded, Some(&self.handle.name)); + Some(decoded) + }) + .map(Arc::new); + let event_snapshot = { + let ctx = global_context(); + let state = ctx.read(); + match state { + Ok(state) => { + Some(state.end_llm_handle(&self.handle, data, metadata, annotated_response)) + } + Err(_) => None, + } + }; + if let Some(event) = event_snapshot { NemoRelayContextState::emit_event(&event, &subscribers); } } diff --git a/crates/core/tests/integration/middleware_tests.rs b/crates/core/tests/integration/middleware_tests.rs index df0eb59f..612c5a83 100644 --- a/crates/core/tests/integration/middleware_tests.rs +++ b/crates/core/tests/integration/middleware_tests.rs @@ -22,22 +22,29 @@ use nemo_relay::api::llm::{ }; use nemo_relay::api::registry::{ deregister_llm_conditional_execution_guardrail, deregister_llm_execution_intercept, - deregister_llm_request_intercept, deregister_llm_stream_execution_intercept, + deregister_llm_request_intercept, deregister_llm_sanitize_request_guardrail, + deregister_llm_sanitize_response_guardrail, deregister_llm_stream_execution_intercept, deregister_tool_conditional_execution_guardrail, deregister_tool_execution_intercept, deregister_tool_request_intercept, deregister_tool_sanitize_request_guardrail, deregister_tool_sanitize_response_guardrail, register_llm_conditional_execution_guardrail, register_llm_execution_intercept, register_llm_request_intercept, + register_llm_sanitize_request_guardrail, register_llm_sanitize_response_guardrail, register_llm_stream_execution_intercept, register_tool_conditional_execution_guardrail, register_tool_execution_intercept, register_tool_request_intercept, register_tool_sanitize_request_guardrail, register_tool_sanitize_response_guardrail, - scope_register_tool_execution_intercept, scope_register_tool_sanitize_request_guardrail, + scope_register_llm_conditional_execution_guardrail, scope_register_llm_execution_intercept, + scope_register_llm_request_intercept, scope_register_llm_sanitize_request_guardrail, + scope_register_llm_sanitize_response_guardrail, scope_register_llm_stream_execution_intercept, + scope_register_tool_conditional_execution_guardrail, scope_register_tool_execution_intercept, + scope_register_tool_request_intercept, scope_register_tool_sanitize_request_guardrail, + scope_register_tool_sanitize_response_guardrail, }; use nemo_relay::api::runtime::NemoRelayContextState; use nemo_relay::api::runtime::global_context; use nemo_relay::api::runtime::{ LlmExecutionNextFn, LlmJsonStream, LlmStreamExecutionNextFn, ToolExecutionNextFn, }; -use nemo_relay::api::runtime::{create_scope_stack, set_thread_scope_stack}; +use nemo_relay::api::runtime::{create_scope_stack, current_scope_stack, set_thread_scope_stack}; use nemo_relay::api::scope::{ScopeHandle, ScopeType}; use nemo_relay::api::scope::{pop_scope, push_scope}; use nemo_relay::api::subscriber::{deregister_subscriber, flush_subscribers, register_subscriber}; @@ -85,6 +92,35 @@ fn captured_events_snapshot(events: &Arc>>) -> Vec { events.lock().unwrap().clone() } +fn assert_middleware_callback_locks_are_free() { + let context = global_context(); + assert!( + context.try_write().is_ok(), + "middleware callback ran while the global registry lock was held" + ); + + let scope_stack = current_scope_stack(); + assert!( + scope_stack.try_write().is_ok(), + "middleware callback ran while the scope stack lock was held" + ); +} + +fn record_middleware_callback(callbacks: &Arc>>, label: &'static str) { + callbacks.lock().unwrap().push(label); +} + +fn assert_middleware_callback_labels( + callbacks: &Arc>>, + expected: &[&'static str], +) { + let mut actual = callbacks.lock().unwrap().clone(); + let mut expected = expected.to_vec(); + actual.sort_unstable(); + expected.sort_unstable(); + assert_eq!(actual, expected); +} + // ========================================================================= // Priority Ordering Tests // ========================================================================= @@ -1580,6 +1616,549 @@ fn test_concurrent_register_and_read() { } } +// ========================================================================= +// Lock Regression Tests +// ========================================================================= + +#[test] +fn test_tool_request_intercept_registry_mutations_apply_to_later_calls() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + let callbacks = Arc::new(Mutex::new(Vec::<&'static str>::new())); + let late_registered = Arc::new(AtomicBool::new(false)); + + let tracked = callbacks.clone(); + let registered = late_registered.clone(); + register_tool_request_intercept( + "snapshot_tool_request_initial", + 1, + false, + Arc::new(move |_, args| { + record_middleware_callback(&tracked, "tool_request_initial"); + assert_middleware_callback_locks_are_free(); + + if !registered.swap(true, Ordering::SeqCst) { + let tracked = tracked.clone(); + register_tool_request_intercept( + "snapshot_tool_request_late", + 2, + false, + Arc::new(move |_, args| { + record_middleware_callback(&tracked, "tool_request_late"); + assert_middleware_callback_locks_are_free(); + Ok(args) + }), + ) + .unwrap(); + } + + Ok(args) + }), + ) + .unwrap(); + + let args = tool_request_intercepts("tool", json!({"round": 1})).unwrap(); + assert_eq!(args["round"], 1); + assert_middleware_callback_labels(&callbacks, &["tool_request_initial"]); + + callbacks.lock().unwrap().clear(); + let args = tool_request_intercepts("tool", json!({"round": 2})).unwrap(); + assert_eq!(args["round"], 2); + assert_middleware_callback_labels(&callbacks, &["tool_request_initial", "tool_request_late"]); + + deregister_tool_request_intercept("snapshot_tool_request_initial").unwrap(); + deregister_tool_request_intercept("snapshot_tool_request_late").unwrap(); +} + +#[test] +fn test_llm_request_intercept_registry_mutations_apply_to_later_calls() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + let callbacks = Arc::new(Mutex::new(Vec::<&'static str>::new())); + let late_registered = Arc::new(AtomicBool::new(false)); + + let tracked = callbacks.clone(); + let registered = late_registered.clone(); + register_llm_request_intercept( + "snapshot_llm_request_initial", + 1, + false, + Arc::new(move |_, request, annotated| { + record_middleware_callback(&tracked, "llm_request_initial"); + assert_middleware_callback_locks_are_free(); + + if !registered.swap(true, Ordering::SeqCst) { + let tracked = tracked.clone(); + register_llm_request_intercept( + "snapshot_llm_request_late", + 2, + false, + Arc::new(move |_, request, annotated| { + record_middleware_callback(&tracked, "llm_request_late"); + assert_middleware_callback_locks_are_free(); + Ok((request, annotated)) + }), + ) + .unwrap(); + } + + Ok((request, annotated)) + }), + ) + .unwrap(); + + let request = llm_request_intercepts( + "llm", + LlmRequest { + headers: serde_json::Map::new(), + content: json!({"round": 1}), + }, + ) + .unwrap(); + assert_eq!(request.content["round"], 1); + assert_middleware_callback_labels(&callbacks, &["llm_request_initial"]); + + callbacks.lock().unwrap().clear(); + let request = llm_request_intercepts( + "llm", + LlmRequest { + headers: serde_json::Map::new(), + content: json!({"round": 2}), + }, + ) + .unwrap(); + assert_eq!(request.content["round"], 2); + assert_middleware_callback_labels(&callbacks, &["llm_request_initial", "llm_request_late"]); + + deregister_llm_request_intercept("snapshot_llm_request_initial").unwrap(); + deregister_llm_request_intercept("snapshot_llm_request_late").unwrap(); +} + +#[tokio::test] +async fn test_tool_middleware_callbacks_run_without_registry_or_scope_locks() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + let scope = setup_isolated_scope("tool_lock_regression"); + let callbacks = Arc::new(Mutex::new(Vec::<&'static str>::new())); + + let tracked = callbacks.clone(); + register_tool_conditional_execution_guardrail( + "lock_global_tool_conditional", + 1, + Arc::new(move |_, _| { + record_middleware_callback(&tracked, "tool_conditional_global"); + assert_middleware_callback_locks_are_free(); + Ok(None) + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + scope_register_tool_conditional_execution_guardrail( + &scope.uuid, + "lock_scope_tool_conditional", + 2, + Arc::new(move |_, _| { + record_middleware_callback(&tracked, "tool_conditional_scope"); + assert_middleware_callback_locks_are_free(); + Ok(None) + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + register_tool_request_intercept( + "lock_global_tool_request", + 1, + false, + Arc::new(move |_, args| { + record_middleware_callback(&tracked, "tool_request_global"); + assert_middleware_callback_locks_are_free(); + Ok(args) + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + scope_register_tool_request_intercept( + &scope.uuid, + "lock_scope_tool_request", + 2, + false, + Arc::new(move |_, args| { + record_middleware_callback(&tracked, "tool_request_scope"); + assert_middleware_callback_locks_are_free(); + Ok(args) + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + register_tool_sanitize_request_guardrail( + "lock_global_tool_sanitize_request", + 1, + Arc::new(move |_, args| { + record_middleware_callback(&tracked, "tool_sanitize_request_global"); + assert_middleware_callback_locks_are_free(); + args + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + scope_register_tool_sanitize_request_guardrail( + &scope.uuid, + "lock_scope_tool_sanitize_request", + 2, + Arc::new(move |_, args| { + record_middleware_callback(&tracked, "tool_sanitize_request_scope"); + assert_middleware_callback_locks_are_free(); + args + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + register_tool_execution_intercept( + "lock_global_tool_execution", + 1, + Arc::new(move |_, args, next| { + record_middleware_callback(&tracked, "tool_execution_global"); + assert_middleware_callback_locks_are_free(); + Box::pin(async move { next(args).await }) + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + scope_register_tool_execution_intercept( + &scope.uuid, + "lock_scope_tool_execution", + 2, + Arc::new(move |_, args, next| { + record_middleware_callback(&tracked, "tool_execution_scope"); + assert_middleware_callback_locks_are_free(); + Box::pin(async move { next(args).await }) + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + register_tool_sanitize_response_guardrail( + "lock_global_tool_sanitize_response", + 1, + Arc::new(move |_, result| { + record_middleware_callback(&tracked, "tool_sanitize_response_global"); + assert_middleware_callback_locks_are_free(); + result + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + scope_register_tool_sanitize_response_guardrail( + &scope.uuid, + "lock_scope_tool_sanitize_response", + 2, + Arc::new(move |_, result| { + record_middleware_callback(&tracked, "tool_sanitize_response_scope"); + assert_middleware_callback_locks_are_free(); + result + }), + ) + .unwrap(); + + let tracked = callbacks.clone(); + let func: ToolExecutionNextFn = Arc::new(move |args| { + record_middleware_callback(&tracked, "tool_func"); + assert_middleware_callback_locks_are_free(); + Box::pin(async move { Ok(args) }) + }); + let result = tool_call_execute( + nemo_relay::api::tool::ToolCallExecuteParams::builder() + .name("tool") + .args(json!({"ok": true})) + .func(func) + .build(), + ) + .await + .unwrap(); + assert_eq!(result["ok"], true); + assert_middleware_callback_labels( + &callbacks, + &[ + "tool_conditional_global", + "tool_conditional_scope", + "tool_request_global", + "tool_request_scope", + "tool_sanitize_request_global", + "tool_sanitize_request_scope", + "tool_execution_global", + "tool_execution_scope", + "tool_func", + "tool_sanitize_response_global", + "tool_sanitize_response_scope", + ], + ); + + deregister_tool_conditional_execution_guardrail("lock_global_tool_conditional").unwrap(); + deregister_tool_request_intercept("lock_global_tool_request").unwrap(); + deregister_tool_sanitize_request_guardrail("lock_global_tool_sanitize_request").unwrap(); + deregister_tool_execution_intercept("lock_global_tool_execution").unwrap(); + deregister_tool_sanitize_response_guardrail("lock_global_tool_sanitize_response").unwrap(); + pop_scope( + nemo_relay::api::scope::PopScopeParams::builder() + .handle_uuid(&scope.uuid) + .build(), + ) + .unwrap(); +} + +#[tokio::test] +async fn test_llm_middleware_callbacks_run_without_registry_or_scope_locks() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + let scope = setup_isolated_scope("llm_lock_regression"); + let callbacks = Arc::new(Mutex::new(Vec::<&'static str>::new())); + + let tracked = callbacks.clone(); + register_llm_conditional_execution_guardrail( + "lock_global_llm_conditional", + 1, + Arc::new(move |_| { + record_middleware_callback(&tracked, "llm_conditional_global"); + assert_middleware_callback_locks_are_free(); + Ok(None) + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + scope_register_llm_conditional_execution_guardrail( + &scope.uuid, + "lock_scope_llm_conditional", + 2, + Arc::new(move |_| { + record_middleware_callback(&tracked, "llm_conditional_scope"); + assert_middleware_callback_locks_are_free(); + Ok(None) + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + register_llm_request_intercept( + "lock_global_llm_request", + 1, + false, + Arc::new(move |_, request, annotated| { + record_middleware_callback(&tracked, "llm_request_global"); + assert_middleware_callback_locks_are_free(); + Ok((request, annotated)) + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + scope_register_llm_request_intercept( + &scope.uuid, + "lock_scope_llm_request", + 2, + false, + Arc::new(move |_, request, annotated| { + record_middleware_callback(&tracked, "llm_request_scope"); + assert_middleware_callback_locks_are_free(); + Ok((request, annotated)) + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + register_llm_sanitize_request_guardrail( + "lock_global_llm_sanitize_request", + 1, + Arc::new(move |request| { + record_middleware_callback(&tracked, "llm_sanitize_request_global"); + assert_middleware_callback_locks_are_free(); + request + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + scope_register_llm_sanitize_request_guardrail( + &scope.uuid, + "lock_scope_llm_sanitize_request", + 2, + Arc::new(move |request| { + record_middleware_callback(&tracked, "llm_sanitize_request_scope"); + assert_middleware_callback_locks_are_free(); + request + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + register_llm_execution_intercept( + "lock_global_llm_execution", + 1, + Arc::new(move |_, request, next| { + record_middleware_callback(&tracked, "llm_execution_global"); + assert_middleware_callback_locks_are_free(); + Box::pin(async move { next(request).await }) + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + scope_register_llm_execution_intercept( + &scope.uuid, + "lock_scope_llm_execution", + 2, + Arc::new(move |_, request, next| { + record_middleware_callback(&tracked, "llm_execution_scope"); + assert_middleware_callback_locks_are_free(); + Box::pin(async move { next(request).await }) + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + register_llm_stream_execution_intercept( + "lock_global_llm_stream_execution", + 1, + Arc::new(move |_, request, next| { + record_middleware_callback(&tracked, "llm_stream_execution_global"); + assert_middleware_callback_locks_are_free(); + Box::pin(async move { next(request).await }) + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + scope_register_llm_stream_execution_intercept( + &scope.uuid, + "lock_scope_llm_stream_execution", + 2, + Arc::new(move |_, request, next| { + record_middleware_callback(&tracked, "llm_stream_execution_scope"); + assert_middleware_callback_locks_are_free(); + Box::pin(async move { next(request).await }) + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + register_llm_sanitize_response_guardrail( + "lock_global_llm_sanitize_response", + 1, + Arc::new(move |response| { + record_middleware_callback(&tracked, "llm_sanitize_response_global"); + assert_middleware_callback_locks_are_free(); + response + }), + ) + .unwrap(); + let tracked = callbacks.clone(); + scope_register_llm_sanitize_response_guardrail( + &scope.uuid, + "lock_scope_llm_sanitize_response", + 2, + Arc::new(move |response| { + record_middleware_callback(&tracked, "llm_sanitize_response_scope"); + assert_middleware_callback_locks_are_free(); + response + }), + ) + .unwrap(); + + let tracked = callbacks.clone(); + let func: LlmExecutionNextFn = Arc::new(move |_| { + record_middleware_callback(&tracked, "llm_func"); + assert_middleware_callback_locks_are_free(); + Box::pin(async move { Ok(json!({"ok": true})) }) + }); + let response = llm_call_execute( + LlmCallExecuteParams::builder() + .name("llm") + .request(LlmRequest { + headers: serde_json::Map::new(), + content: json!({"messages": []}), + }) + .func(func) + .build(), + ) + .await + .unwrap(); + assert_eq!(response["ok"], true); + + let tracked = callbacks.clone(); + let stream_func: LlmStreamExecutionNextFn = Arc::new(move |_| { + record_middleware_callback(&tracked, "llm_stream_func"); + assert_middleware_callback_locks_are_free(); + Box::pin(async move { + let stream = tokio_stream::iter(vec![Ok(json!({"chunk": true}))]); + Ok(Box::pin(stream) as LlmJsonStream) + }) + }); + let tracked = callbacks.clone(); + let collector = Box::new(move |_| { + record_middleware_callback(&tracked, "llm_collector"); + assert_middleware_callback_locks_are_free(); + Ok(()) + }); + let tracked = callbacks.clone(); + let finalizer = Box::new(move || { + record_middleware_callback(&tracked, "llm_finalizer"); + assert_middleware_callback_locks_are_free(); + json!({"stream": true}) + }); + let mut stream = llm_stream_call_execute( + LlmStreamCallExecuteParams::builder() + .name("llm-stream") + .request(LlmRequest { + headers: serde_json::Map::new(), + content: json!({"messages": []}), + }) + .func(stream_func) + .collector(collector) + .finalizer(finalizer) + .build(), + ) + .await + .unwrap(); + while let Some(chunk) = stream.next().await { + chunk.unwrap(); + } + assert_middleware_callback_labels( + &callbacks, + &[ + "llm_conditional_global", + "llm_conditional_global", + "llm_conditional_scope", + "llm_conditional_scope", + "llm_request_global", + "llm_request_global", + "llm_request_scope", + "llm_request_scope", + "llm_sanitize_request_global", + "llm_sanitize_request_global", + "llm_sanitize_request_scope", + "llm_sanitize_request_scope", + "llm_execution_global", + "llm_execution_scope", + "llm_func", + "llm_stream_execution_global", + "llm_stream_execution_scope", + "llm_stream_func", + "llm_collector", + "llm_finalizer", + "llm_sanitize_response_global", + "llm_sanitize_response_global", + "llm_sanitize_response_scope", + "llm_sanitize_response_scope", + ], + ); + + deregister_llm_conditional_execution_guardrail("lock_global_llm_conditional").unwrap(); + deregister_llm_request_intercept("lock_global_llm_request").unwrap(); + deregister_llm_sanitize_request_guardrail("lock_global_llm_sanitize_request").unwrap(); + deregister_llm_execution_intercept("lock_global_llm_execution").unwrap(); + deregister_llm_stream_execution_intercept("lock_global_llm_stream_execution").unwrap(); + deregister_llm_sanitize_response_guardrail("lock_global_llm_sanitize_response").unwrap(); + pop_scope( + nemo_relay::api::scope::PopScopeParams::builder() + .handle_uuid(&scope.uuid) + .build(), + ) + .unwrap(); +} + // ========================================================================= // Full Pipeline Integration Test // ========================================================================= diff --git a/crates/core/tests/unit/context_tests.rs b/crates/core/tests/unit/context_tests.rs index f402de4f..7b15dda3 100644 --- a/crates/core/tests/unit/context_tests.rs +++ b/crates/core/tests/unit/context_tests.rs @@ -279,7 +279,9 @@ fn context_state_supports_extensions_events_and_builders() { headers: Map::new(), content: json!({"messages": []}), }; - let sanitized = state.llm_sanitize_request_chain(request.clone(), &[]); + let entries = state.llm_sanitize_request_entries(&[]); + let sanitized = + NemoRelayContextState::llm_sanitize_request_snapshot_chain(request.clone(), &entries); assert!(sanitized.headers.is_empty()); let events = Arc::new(Mutex::new(Vec::::new()));