From 4dee245e7e9cd946c70982e4f8896c064da2efea Mon Sep 17 00:00:00 2001 From: Amrit Krishnan Date: Fri, 5 Jun 2026 08:33:04 -0400 Subject: [PATCH] Add Monte Carlo sampling mode to SaliencyEngine n_samples>1 switches from logprob-drop scoring to sampling: each masked context is run n_samples times at sample_temperature and importance = P(original | full) - P(original | masked). Produces a genuine probability distribution instead of near-binary scores when the model is confident at temperature=0. - 11 new tests covering sampling mode, call counts, and edge cases - Demo updated to use n_samples=5 with a loan review scenario --- examples/demo_saliency.py | 124 +++++++++++---- src/motive/saliency.py | 112 ++++++++++--- tests/test_saliency.py | 327 ++++++++++++++++++++++++++++---------- 3 files changed, 417 insertions(+), 146 deletions(-) diff --git a/examples/demo_saliency.py b/examples/demo_saliency.py index a0b3dc6..71c1e3e 100644 --- a/examples/demo_saliency.py +++ b/examples/demo_saliency.py @@ -1,4 +1,9 @@ -"""Demo: hierarchical occlusion saliency on a support-ticket routing decision.""" +"""Demo: hierarchical occlusion saliency on a loan application review decision. + +Designed to show a realistic importance distribution across multiple segments. +The decision (approve / flag / request_docs) depends on partial signals from +several sources, so masking each segment causes a different drop in confidence. +""" import asyncio import os @@ -12,70 +17,116 @@ API_KEY = os.environ["VECTOR_API_KEY"] MODEL = "Qwen3-Coder-Next" -# --- Context segments (what the agent sees when deciding which tool to call) --- +# --- Context segments --- +# Each segment contributes partial evidence. No single segment is decisive, +# so masking them produces a spread of importance scores rather than 0/1. SEGMENTS = [ Segment( id="system", - content="You are a support agent. 
Escalate tickets that are urgent or have been unresolved for more than 48 hours.", + content=( + "You are a loan underwriting assistant. " + "Approve applications that clearly meet all criteria. " + "Flag for manual review if any criterion is borderline. " + "Request more documents if required information is missing." + ), label="System prompt", level=SegmentLevel.DOCUMENT, ), Segment( - id="user_msg", - content="My payment has been failing for 3 days and my account is now locked. I need this fixed urgently.", - label="User message", + id="application", + content=( + "Applicant: Jordan Lee. " + "Requested loan: $42,000. " + "Stated annual income: $68,000. " + "Credit score: 694. " + "Employment: salaried, 2.5 years at current employer." + ), + label="Loan application", + level=SegmentLevel.DOCUMENT, + ), + Segment( + id="credit_policy", + content=( + "Credit score policy: scores above 720 qualify for standard approval. " + "Scores between 660 and 720 require manual review. " + "Scores below 660 are declined automatically." + ), + label="Policy: credit score", level=SegmentLevel.DOCUMENT, ), Segment( - id="doc_1", - content="Escalation policy: accounts locked for more than 48 hours require immediate human review. Do not attempt automated resolution.", - label="Retrieved doc: escalation policy", + id="income_policy", + content=( + "Debt-to-income policy: the requested loan must not exceed 65% of annual income. " + "Borderline cases (60-65%) require a supervisor sign-off." + ), + label="Policy: income ratio", level=SegmentLevel.DOCUMENT, ), Segment( - id="doc_2", - content="Payment retry guide: ask the user to clear browser cache and retry. Most payment failures resolve within 24 hours.", - label="Retrieved doc: payment retry guide", + id="employment_policy", + content=( + "Employment policy: applicants must have at least 12 months of continuous employment. " + "Less than 24 months at the current employer is considered borderline." 
+ ), + label="Policy: employment", + level=SegmentLevel.DOCUMENT, + ), + Segment( + id="prior_history", + content=( + "Credit history: no prior defaults. " + "One missed payment 18 months ago, now resolved. " + "No open collections or bankruptcies." + ), + label="Prior credit history", level=SegmentLevel.DOCUMENT, ), ] -# --- Full messages for the decision call --- +# --- Messages --- MESSAGES = [ {"role": "system", "content": SEGMENTS[0].content}, - {"role": "user", "content": SEGMENTS[1].content}, { "role": "user", - "content": (f"Context documents:\n\n[Doc 1] {SEGMENTS[2].content}\n\n[Doc 2] {SEGMENTS[3].content}"), + "content": ( + f"Please review this loan application and decide on the next action.\n\n" + f"Application:\n{SEGMENTS[1].content}\n\n" + f"Relevant policies and history:\n" + f"[Credit policy] {SEGMENTS[2].content}\n" + f"[Income policy] {SEGMENTS[3].content}\n" + f"[Employment policy] {SEGMENTS[4].content}\n" + f"[Credit history] {SEGMENTS[5].content}" + ), }, ] -# --- Tool definitions --- +# --- Tools --- TOOLS = [ { "type": "function", "function": { - "name": "escalate_to_human", - "description": "Escalate the ticket to a human support agent.", + "name": "approve_loan", + "description": "Approve the loan application. All criteria are clearly met.", "parameters": {"type": "object", "properties": {}, "required": []}, }, }, { "type": "function", "function": { - "name": "send_retry_instructions", - "description": "Send automated payment retry instructions to the user.", + "name": "flag_for_manual_review", + "description": "Flag the application for manual review by a senior underwriter. 
Use when any criterion is borderline.", "parameters": {"type": "object", "properties": {}, "required": []}, }, }, { "type": "function", "function": { - "name": "auto_resolve", - "description": "Mark the ticket as resolved automatically.", + "name": "request_additional_documents", + "description": "Ask the applicant for missing or incomplete documentation before proceeding.", "parameters": {"type": "object", "properties": {}, "required": []}, }, }, @@ -85,10 +136,14 @@ async def main() -> None: """Run the saliency demo.""" client = AsyncOpenAI(base_url=PROXY_URL, api_key=API_KEY) - engine = SaliencyEngine(client=client, model=MODEL, top_k=2) + # n_samples=5: run each masked context 5 times at temperature=0.7 and + # measure P(original_decision | masked). Produces a genuine distribution + # rather than near-binary logprob scores. + engine = SaliencyEngine(client=client, model=MODEL, top_k=2, n_samples=5) print(f"Model: {MODEL}") - print(f"Running hierarchical occlusion on {len(SEGMENTS)} segments...\n") + print(f"Segments: {len(SEGMENTS)}") + print("Running hierarchical occlusion (n_samples=5, temperature=0.7)...\n") result = await engine.explain_async( messages=MESSAGES, # type: ignore[arg-type] @@ -96,23 +151,24 @@ async def main() -> None: tools=TOOLS, # type: ignore[arg-type] ) - print(f"Original decision: {result.original_decision}\n") - print("=== Pass 1: Segment-level importance ===") + print(f"Decision: {result.original_decision}\n") + + print("=== Pass 1: segment-level importance ===") doc_scores = [s for s in result.top if "__s" not in s.segment_id] for score in doc_scores: - bar = "█" * int(score.importance * 20) - changed = " ← flipped" if score.decision_changed else "" - print(f" {score.label:<40} {score.importance:.3f} {bar}{changed}") + bar = "█" * int(score.importance * 30) + changed = " [flipped]" if score.decision_changed else "" + print(f" {score.label:<38} {score.importance:.3f} {bar}{changed}") - print("\n=== Pass 2: Sentence-level (top segments 
drilled in) ===") + print("\n=== Pass 2: sentence-level (top segments) ===") sent_scores = [s for s in result.top if "__s" in s.segment_id] if sent_scores: for score in sent_scores: - bar = "█" * int(score.importance * 20) - changed = " ← flipped" if score.decision_changed else "" - print(f" {score.label[:60]:<62} {score.importance:.3f} {bar}{changed}") + bar = "█" * int(score.importance * 30) + changed = " [flipped]" if score.decision_changed else "" + print(f" {score.label[:55]:<57} {score.importance:.3f} {bar}{changed}") else: - print(" (no multi-sentence segments found)") + print(" (top segments were single sentences)") print(f"\n{result.summary()}") diff --git a/src/motive/saliency.py b/src/motive/saliency.py index 937f501..0eb2d1d 100644 --- a/src/motive/saliency.py +++ b/src/motive/saliency.py @@ -29,7 +29,7 @@ def _extract_tool_name(response: Any) -> str: def _logprob_for_tool(response: Any, tool_name: str) -> float | None: """Sum logprobs over the tokens that spell out `tool_name` in the response. - Tool names are multi-token (e.g. "escalate_to_human" → ["escal","ate","_to","_human"]), + Tool names are multi-token (e.g. "escalate_to_human" -> ["escal","ate","_to","_human"]), so we reconstruct the full generated string, locate the tool name span, and sum the logprobs of every token that overlaps with it. """ @@ -107,9 +107,20 @@ class SaliencyEngine: Pass 1 masks at segment (document/message) level. Pass 2 drills into the top_k segments at sentence level. - Importance is the drop in log-probability of the original tool choice - when a segment is masked. Falls back to binary (flipped = 1.0, - unchanged = 0.0) when logprobs are unavailable. + Scoring modes + ------------- + n_samples=1 (default): + Single deterministic run (temperature=0). Importance = drop in + log-probability of the original tool when the segment is masked. + Fast but can produce near-binary scores when the model is very confident. 
+ + n_samples>1: + Monte Carlo sampling (temperature=sample_temperature). Runs each + masked context n_samples times and measures how often the original + decision survives. Importance = P(original | full) - P(original | masked). + Slower but produces a genuine probability distribution (e.g. 0.0, 0.2, + 0.4, 0.6, 0.8, 1.0 with n_samples=5). Use this when logprob drops are + too small to distinguish importance levels. """ def __init__( @@ -117,10 +128,14 @@ def __init__( client: AsyncOpenAI, model: str, top_k: int = _TOP_K_DEFAULT, + n_samples: int = 1, + sample_temperature: float = 0.7, ) -> None: self.client = client self.model = model self.top_k = top_k + self.n_samples = n_samples + self.sample_temperature = sample_temperature def explain( self, @@ -139,16 +154,24 @@ async def explain_async( tools: list[ChatCompletionToolParam], tool_choice: ToolChoice = "auto", ) -> SaliencyResult: - """Async entry point.""" - original_resp = await self._call(messages, tools, tool_choice, logprobs=True) - original_tool = _extract_tool_name(original_resp) - original_lp = _logprob_for_tool(original_resp, original_tool) + """Run explanation asynchronously.""" + if self.n_samples > 1: + original_tool = await self._get_tool(messages, tools, tool_choice) # temperature=0 reference decision + original_prob = await self._sample_probability(messages, tools, tool_choice, target_tool=original_tool) + original_lp = None + else: + original_resp = await self._call(messages, tools, tool_choice, logprobs=True) + original_tool = _extract_tool_name(original_resp) + original_lp = _logprob_for_tool(original_resp, original_tool) + original_prob = 1.0 # Pass 1: segment level - pass1 = await self._occlusion_pass(messages, segments, tools, tool_choice, original_tool, original_lp) + pass1 = await self._occlusion_pass( + messages, segments, tools, tool_choice, original_tool, original_lp, original_prob + ) pass1 = _normalise(pass1) - # Pass 2: sentence level on top_k segments that actually had non-zero importance + # Pass 2: sentence level on top_k segments with
non-zero importance top_ids = { s.segment_id for s in sorted(pass1, key=lambda s: s.importance, reverse=True)[: self.top_k] @@ -171,7 +194,7 @@ async def explain_async( for i, s in enumerate(sentences) ] sub_scores = await self._occlusion_pass( - messages, sub_segments, tools, tool_choice, original_tool, original_lp + messages, sub_segments, tools, tool_choice, original_tool, original_lp, original_prob ) pass2.extend(_normalise(sub_scores)) @@ -189,9 +212,10 @@ async def _occlusion_pass( tool_choice: ToolChoice, original_tool: str, original_lp: float | None, + original_prob: float, ) -> list[SaliencyScore]: tasks = [ - self._score_segment(messages, segments, i, tools, tool_choice, original_tool, original_lp) + self._score_segment(messages, segments, i, tools, tool_choice, original_tool, original_lp, original_prob) for i in range(len(segments)) ] return list(await asyncio.gather(*tasks)) @@ -205,20 +229,26 @@ async def _score_segment( tool_choice: ToolChoice, original_tool: str, original_lp: float | None, + original_prob: float, ) -> SaliencyScore: masked_msgs = _mask_messages(messages, segments, mask_idx) - resp = await self._call(masked_msgs, tools, tool_choice, logprobs=True) - masked_tool = _extract_tool_name(resp) - - if original_lp is not None: - masked_lp = _logprob_for_tool(resp, original_tool) - importance = ( - max(0.0, original_lp - masked_lp) - if masked_lp is not None - else (1.0 if masked_tool != original_tool else 0.0) - ) + + if self.n_samples > 1: + masked_prob = await self._sample_probability(masked_msgs, tools, tool_choice, target_tool=original_tool) + importance = max(0.0, original_prob - masked_prob) + masked_tool = original_tool if masked_prob >= 0.5 else "other" else: - importance = 1.0 if masked_tool != original_tool else 0.0 + resp = await self._call(masked_msgs, tools, tool_choice, logprobs=True) + masked_tool = _extract_tool_name(resp) + if original_lp is not None: + masked_lp = _logprob_for_tool(resp, original_tool) + importance = ( + 
max(0.0, original_lp - masked_lp) + if masked_lp is not None + else (1.0 if masked_tool != original_tool else 0.0) + ) + else: + importance = 1.0 if masked_tool != original_tool else 0.0 seg = segments[mask_idx] return SaliencyScore( @@ -230,12 +260,46 @@ async def _score_segment( masked_decision=masked_tool, ) + async def _sample_probability( + self, + messages: list[ChatCompletionMessageParam], + tools: list[ChatCompletionToolParam], + tool_choice: ToolChoice, + target_tool: str | None = None, + ) -> float: + """Estimate P(target_tool | messages) via n_samples draws at sample_temperature.""" + resps = await asyncio.gather( + *[ + self._call(messages, tools, tool_choice, temperature=self.sample_temperature) + for _ in range(self.n_samples) + ] + ) + if target_tool is None: + # First call: use the majority decision as target + target_tool = max( + {_extract_tool_name(r) for r in resps}, + key=lambda t: sum(1 for r in resps if _extract_tool_name(r) == t), + ) + matches = sum(1 for r in resps if _extract_tool_name(r) == target_tool) + return matches / self.n_samples + + async def _get_tool( + self, + messages: list[ChatCompletionMessageParam], + tools: list[ChatCompletionToolParam], + tool_choice: ToolChoice, + ) -> str: + """Get the deterministic tool choice (temperature=0) for the original context.""" + resp = await self._call(messages, tools, tool_choice) + return _extract_tool_name(resp) + async def _call( self, messages: list[ChatCompletionMessageParam], tools: list[ChatCompletionToolParam], tool_choice: ToolChoice, logprobs: bool = False, + temperature: float = 0.0, ) -> Any: return await self.client.chat.completions.create( model=self.model, @@ -244,5 +308,5 @@ async def _call( tool_choice=tool_choice, logprobs=logprobs, top_logprobs=5 if logprobs else None, - temperature=0, + temperature=temperature, ) diff --git a/tests/test_saliency.py b/tests/test_saliency.py index 4e9655b..4d7ad68 100644 --- a/tests/test_saliency.py +++ b/tests/test_saliency.py @@ -69,6 
+69,51 @@ def _text_response(text: str) -> MagicMock: return resp +def _make_engine(n_samples: int = 1, top_k: int = 2) -> tuple[SaliencyEngine, AsyncMock]: + """Return an engine wired to a mock OpenAI client.""" + client = MagicMock() + client.chat = MagicMock() + client.chat.completions = MagicMock() + client.chat.completions.create = AsyncMock() + return SaliencyEngine(client=client, model="test-model", top_k=top_k, n_samples=n_samples), client + + +def _sample_segments() -> list[Segment]: + return [ + Segment(id="system", content="You are a support agent. Escalate urgent issues.", label="System prompt"), + Segment(id="user_msg", content="My account is locked for 3 days.", label="User message"), + Segment(id="doc_1", content="Locked accounts require human review.", label="Escalation policy"), + ] + + +def _sample_messages() -> list[dict]: + return [ + {"role": "system", "content": "You are a support agent. Escalate urgent issues."}, + {"role": "user", "content": "My account is locked for 3 days. 
Locked accounts require human review."}, + ] + + +def _sample_tools() -> list[dict]: + return [ + { + "type": "function", + "function": { + "name": "escalate_to_human", + "description": "Escalate.", + "parameters": {"type": "object", "properties": {}, "required": []}, + }, + }, + { + "type": "function", + "function": { + "name": "send_reminder", + "description": "Send reminder.", + "parameters": {"type": "object", "properties": {}, "required": []}, + }, + }, + ] + + # --------------------------------------------------------------------------- # _extract_tool_name # --------------------------------------------------------------------------- @@ -105,7 +150,6 @@ def test_single_token_match(self) -> None: assert result == pytest.approx(-0.5) def test_multi_token_match_sums_logprobs(self) -> None: - # "escalate_to_human" split as "escal" + "ate" + "_to" + "_human" logprobs = [ {"token": " None: resp = _tool_response("tool_a", logprobs=[]) assert _logprob_for_tool(resp, "tool_a") is None + def test_partial_overlap_tokens_included(self) -> None: + # Token spans: "too" covers the start of "tool_x", "l_x" covers the end + logprobs = [ + {"token": "too", "logprob": -0.3}, + {"token": "l_x", "logprob": -0.4}, + ] + resp = _tool_response("tool_x", logprobs) + result = _logprob_for_tool(resp, "tool_x") + assert result == pytest.approx(-0.3 + -0.4) + # --------------------------------------------------------------------------- # _mask_messages # --------------------------------------------------------------------------- @@ -186,6 +240,15 @@ def test_returns_new_list_not_mutating_original(self) -> None: _mask_messages(messages, segments, mask_idx=0) assert messages[0]["content"] == original_content + def test_masks_only_first_occurrence(self) -> None: + segments = [Segment(id="s", content="repeat")] + messages = [{"role": "user", "content": "repeat and repeat again"}] + masked = _mask_messages(messages, segments, mask_idx=0) + content = masked[0]["content"] + assert "[CONTENT REDACTED]" in content + assert content.count("[CONTENT REDACTED]") == 1 + assert "repeat
again" in content + # --------------------------------------------------------------------------- # _split_sentences @@ -218,6 +281,10 @@ def test_filters_whitespace_only_parts(self) -> None: result = _split_sentences(" First. Second. ") assert all(s.strip() for s in result) + def test_three_sentences(self) -> None: + result = _split_sentences("One. Two. Three.") + assert len(result) == 3 + # --------------------------------------------------------------------------- # _normalise @@ -259,62 +326,23 @@ def test_preserves_segment_ids(self) -> None: normalised = _normalise(scores) assert {s.segment_id for s in normalised} == {"x", "y"} + def test_single_nonzero_score_becomes_one(self) -> None: + scores = [self._score("a", 0.0), self._score("b", 3.7)] + normalised = _normalise(scores) + by_id = {s.segment_id: s.importance for s in normalised} + assert by_id["b"] == pytest.approx(1.0) + assert by_id["a"] == pytest.approx(0.0) + # --------------------------------------------------------------------------- -# SaliencyEngine — async, mocked client +# SaliencyEngine (n_samples=1, logprob mode) # --------------------------------------------------------------------------- -def _make_engine() -> tuple[SaliencyEngine, AsyncMock]: - """Return an engine wired to a mock OpenAI client.""" - client = MagicMock() - client.chat = MagicMock() - client.chat.completions = MagicMock() - client.chat.completions.create = AsyncMock() - return SaliencyEngine(client=client, model="test-model", top_k=2), client - - -def _sample_segments() -> list[Segment]: - return [ - Segment(id="system", content="You are a support agent. Escalate urgent issues.", label="System prompt"), - Segment(id="user_msg", content="My account is locked for 3 days.", label="User message"), - Segment(id="doc_1", content="Locked accounts require human review.", label="Escalation policy"), - ] - - -def _sample_messages() -> list[dict]: - return [ - {"role": "system", "content": "You are a support agent. 
Escalate urgent issues."}, - {"role": "user", "content": "My account is locked for 3 days. Locked accounts require human review."}, - ] - - -def _sample_tools() -> list[dict]: - return [ - { - "type": "function", - "function": { - "name": "escalate_to_human", - "description": "Escalate.", - "parameters": {"type": "object", "properties": {}, "required": []}, - }, - }, - { - "type": "function", - "function": { - "name": "send_reminder", - "description": "Send reminder.", - "parameters": {"type": "object", "properties": {}, "required": []}, - }, - }, - ] - - @pytest.mark.asyncio class TestSaliencyEngineAsync: async def test_returns_saliency_result_with_correct_decision(self) -> None: engine, client = _make_engine() - # All calls return the same tool — no flips client.chat.completions.create.return_value = _tool_response("escalate_to_human") result = await engine.explain_async( @@ -328,22 +356,17 @@ async def test_returns_saliency_result_with_correct_decision(self) -> None: async def test_flipped_segment_gets_high_importance(self) -> None: engine, client = _make_engine() - segments = _sample_segments() - responses = [] - # Original call - responses.append(_tool_response("escalate_to_human")) - # Masking system → same - responses.append(_tool_response("escalate_to_human")) - # Masking user_msg → flips - responses.append(_tool_response("send_reminder")) - # Masking doc_1 → same - responses.append(_tool_response("escalate_to_human")) - + responses = [ + _tool_response("escalate_to_human"), # original + _tool_response("escalate_to_human"), # mask system → same + _tool_response("send_reminder"), # mask user_msg → flips + _tool_response("escalate_to_human"), # mask doc_1 → same + ] client.chat.completions.create.side_effect = responses result = await engine.explain_async( messages=_sample_messages(), # type: ignore[arg-type] - segments=segments, + segments=_sample_segments(), tools=_sample_tools(), # type: ignore[arg-type] ) @@ -361,15 +384,10 @@ async def 
test_logprob_drop_used_as_importance_when_available(self) -> None: ] messages = [{"role": "user", "content": "important context irrelevant context"}] - lp_tokens = [ - {"token": "tool", "logprob": -0.01}, - {"token": "_a", "logprob": -0.01}, - ] - # Original: logprob of "tool_a" = -0.02 - original_resp = _tool_response("tool_a", lp_tokens) - # Mask seg_a: logprob drops to -1.5 → importance = -0.02 - (-1.5) = 1.48 + original_resp = _tool_response( + "tool_a", [{"token": "tool", "logprob": -0.01}, {"token": "_a", "logprob": -0.01}] + ) masked_a = _tool_response("tool_a", [{"token": "tool", "logprob": -0.8}, {"token": "_a", "logprob": -0.7}]) - # Mask seg_b: logprob barely changes → low importance masked_b = _tool_response("tool_a", [{"token": "tool", "logprob": -0.01}, {"token": "_a", "logprob": -0.02}]) client.chat.completions.create.side_effect = [original_resp, masked_a, masked_b] @@ -384,9 +402,26 @@ async def test_logprob_drop_used_as_importance_when_available(self) -> None: by_id = {s.segment_id: s for s in pass1} assert by_id["seg_a"].importance > by_id["seg_b"].importance + async def test_importance_clipped_at_zero_when_masked_logprob_higher(self) -> None: + engine, client = _make_engine() + segments = [Segment(id="seg_a", content="context", label="A")] + messages = [{"role": "user", "content": "context"}] + + # Original logprob lower than masked (masked is more confident) — drop should clip to 0 + original_resp = _tool_response("tool_a", [{"token": "tool_a", "logprob": -1.0}]) + masked_resp = _tool_response("tool_a", [{"token": "tool_a", "logprob": -0.1}]) + client.chat.completions.create.side_effect = [original_resp, masked_resp] + + result = await engine.explain_async( + messages=messages, # type: ignore[arg-type] + segments=segments, + tools=_sample_tools(), # type: ignore[arg-type] + ) + + assert result.scores[0].importance == pytest.approx(0.0) + async def test_no_pass2_when_all_segments_zero_importance(self) -> None: engine, client = _make_engine() - # 
Every call returns the same tool — no flips, no logprobs → all zero client.chat.completions.create.return_value = _tool_response("escalate_to_human") result = await engine.explain_async( @@ -395,8 +430,7 @@ async def test_no_pass2_when_all_segments_zero_importance(self) -> None: tools=_sample_tools(), # type: ignore[arg-type] ) - sentence_scores = [s for s in result.scores if "__s" in s.segment_id] - assert sentence_scores == [] + assert [s for s in result.scores if "__s" in s.segment_id] == [] async def test_pass2_only_drills_into_nonzero_segments(self) -> None: engine, client = _make_engine() @@ -408,17 +442,13 @@ async def test_pass2_only_drills_into_nonzero_segments(self) -> None: {"role": "system", "content": "You are an agent."}, {"role": "user", "content": "Urgent issue. Account locked."}, ] - - responses = [] - # Original - responses.append(_tool_response("escalate_to_human")) - # Pass 1: mask system → same; mask user_msg → flips - responses.append(_tool_response("escalate_to_human")) - responses.append(_tool_response("send_reminder")) - # Pass 2: two sentences in user_msg - responses.append(_tool_response("escalate_to_human")) # mask "Urgent issue." - responses.append(_tool_response("send_reminder")) # mask "Account locked." - + responses = [ + _tool_response("escalate_to_human"), # original + _tool_response("escalate_to_human"), # mask system → same + _tool_response("send_reminder"), # mask user_msg → flips + _tool_response("escalate_to_human"), # pass2: mask "Urgent issue." + _tool_response("send_reminder"), # pass2: mask "Account locked." 
+ ] client.chat.completions.create.side_effect = responses result = await engine.explain_async( @@ -428,23 +458,20 @@ async def test_pass2_only_drills_into_nonzero_segments(self) -> None: ) sentence_scores = [s for s in result.scores if "__s" in s.segment_id] - # Only user_msg should be drilled (system had 0 importance) assert all(s.segment_id.startswith("user_msg") for s in sentence_scores) assert len(sentence_scores) == 2 - async def test_call_count_is_1_plus_n_segments_for_pass1(self) -> None: + async def test_call_count_is_1_plus_n_segments_minimum(self) -> None: engine, client = _make_engine() - segments = _sample_segments() # 3 segments client.chat.completions.create.return_value = _tool_response("escalate_to_human") await engine.explain_async( messages=_sample_messages(), # type: ignore[arg-type] - segments=segments, + segments=_sample_segments(), # 3 segments tools=_sample_tools(), # type: ignore[arg-type] ) - # 1 original + 3 masked = 4 calls minimum (plus any pass2 calls) - assert client.chat.completions.create.call_count >= 4 + assert client.chat.completions.create.call_count >= 4 # 1 original + 3 masked class TestSaliencyEngineSync: @@ -460,3 +487,127 @@ def test_explain_sync_returns_result(self) -> None: assert result.original_decision == "escalate_to_human" assert len(result.scores) > 0 + + +# --------------------------------------------------------------------------- +# SaliencyEngine (n_samples>1, sampling mode) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestSaliencyEngineSamplingMode: + async def test_sampling_mode_uses_probability_drop(self) -> None: + engine, client = _make_engine(n_samples=4) + segments = [ + Segment(id="seg_a", content="key context", label="A"), + Segment(id="seg_b", content="irrelevant", label="B"), + ] + messages = [{"role": "user", "content": "key context irrelevant"}] + + # Original tool determined via _get_tool (1 call at temp=0) + # Then 
_sample_probability for original: 4 calls (all return tool_a) → P=1.0 + # Mask seg_a: 4 calls → 1 returns tool_a → P=0.25 → importance=0.75 + # Mask seg_b: 4 calls → all return tool_a → P=1.0 → importance=0.0 + responses = ( + [_tool_response("tool_a")] # _get_tool + + [_tool_response("tool_a")] * 4 # original probability sampling + + [ + _tool_response("tool_a"), + _tool_response("send_reminder"), + _tool_response("send_reminder"), + _tool_response("send_reminder"), + ] # mask seg_a + + [_tool_response("tool_a")] * 4 # mask seg_b + ) + client.chat.completions.create.side_effect = responses + + result = await engine.explain_async( + messages=messages, # type: ignore[arg-type] + segments=segments, + tools=_sample_tools(), # type: ignore[arg-type] + ) + + pass1 = [s for s in result.scores if "__s" not in s.segment_id] + by_id = {s.segment_id: s for s in pass1} + assert by_id["seg_a"].importance > by_id["seg_b"].importance + + async def test_sampling_mode_importance_proportional_to_prob_drop(self) -> None: + engine, client = _make_engine(n_samples=4) + segments = [Segment(id="seg_a", content="context", label="A")] + messages = [{"role": "user", "content": "context"}] + + # Original: P=1.0 (4/4 match) + # Mask seg_a: P=0.5 (2/4 match) → raw importance = 0.5 + responses = ( + [_tool_response("tool_a")] # _get_tool + + [_tool_response("tool_a")] * 4 # original probability + + [ + _tool_response("tool_a"), + _tool_response("tool_a"), + _tool_response("send_reminder"), + _tool_response("send_reminder"), + ] # mask + ) + client.chat.completions.create.side_effect = responses + + result = await engine.explain_async( + messages=messages, # type: ignore[arg-type] + segments=segments, + tools=_sample_tools(), # type: ignore[arg-type] + ) + + # Only one segment, normalises to 1.0 regardless of raw value + assert result.scores[0].importance == pytest.approx(1.0) + + async def test_sampling_mode_decision_unchanged_when_prob_majority_same(self) -> None: + engine, client = 
_make_engine(n_samples=4) + segments = [Segment(id="seg_a", content="context", label="A")] + messages = [{"role": "user", "content": "context"}] + + # Mask: 3/4 still return tool_a → majority same → decision_changed=False + responses = ( + [_tool_response("tool_a")] # _get_tool + + [_tool_response("tool_a")] * 4 # original probability + + [ + _tool_response("tool_a"), + _tool_response("tool_a"), + _tool_response("tool_a"), + _tool_response("send_reminder"), + ] + ) + client.chat.completions.create.side_effect = responses + + result = await engine.explain_async( + messages=messages, # type: ignore[arg-type] + segments=segments, + tools=_sample_tools(), # type: ignore[arg-type] + ) + + assert result.scores[0].decision_changed is False + + async def test_sampling_mode_call_count(self) -> None: + # n_samples=3, 2 segments: + # 1 _get_tool + 3 original_prob + (3 mask_seg_a + 3 mask_seg_b) = 10 calls + engine, client = _make_engine(n_samples=3) + segments = [ + Segment(id="seg_a", content="a", label="A"), + Segment(id="seg_b", content="b", label="B"), + ] + messages = [{"role": "user", "content": "a b"}] + client.chat.completions.create.return_value = _tool_response("tool_a") + + await engine.explain_async( + messages=messages, # type: ignore[arg-type] + segments=segments, + tools=_sample_tools(), # type: ignore[arg-type] + ) + + assert client.chat.completions.create.call_count == 10 + + async def test_n_samples_default_is_1(self) -> None: + engine, _ = _make_engine() + assert engine.n_samples == 1 + + async def test_sample_temperature_default(self) -> None: + engine, _ = _make_engine() + assert engine.sample_temperature == pytest.approx(0.7)