Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 90 additions & 34 deletions examples/demo_saliency.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
"""Demo: hierarchical occlusion saliency on a support-ticket routing decision."""
"""Demo: hierarchical occlusion saliency on a loan application review decision.

Designed to show a realistic importance distribution across multiple segments.
The decision (approve / flag / request_docs) depends on partial signals from
several sources, so masking each segment causes a different drop in confidence.
"""

import asyncio
import os
Expand All @@ -12,70 +17,116 @@
API_KEY = os.environ["VECTOR_API_KEY"]
MODEL = "Qwen3-Coder-Next"

# --- Context segments (what the agent sees when deciding which tool to call) ---
# --- Context segments ---
# Each segment contributes partial evidence. No single segment is decisive,
# so masking them produces a spread of importance scores rather than 0/1.

SEGMENTS = [
Segment(
id="system",
content="You are a support agent. Escalate tickets that are urgent or have been unresolved for more than 48 hours.",
content=(
"You are a loan underwriting assistant. "
"Approve applications that clearly meet all criteria. "
"Flag for manual review if any criterion is borderline. "
"Request more documents if required information is missing."
),
label="System prompt",
level=SegmentLevel.DOCUMENT,
),
Segment(
id="user_msg",
content="My payment has been failing for 3 days and my account is now locked. I need this fixed urgently.",
label="User message",
id="application",
content=(
"Applicant: Jordan Lee. "
"Requested loan: $42,000. "
"Stated annual income: $68,000. "
"Credit score: 694. "
"Employment: salaried, 2.5 years at current employer."
),
label="Loan application",
level=SegmentLevel.DOCUMENT,
),
Segment(
id="credit_policy",
content=(
"Credit score policy: scores above 720 qualify for standard approval. "
"Scores between 660 and 720 require manual review. "
"Scores below 660 are declined automatically."
),
label="Policy: credit score",
level=SegmentLevel.DOCUMENT,
),
Segment(
id="doc_1",
content="Escalation policy: accounts locked for more than 48 hours require immediate human review. Do not attempt automated resolution.",
label="Retrieved doc: escalation policy",
id="income_policy",
content=(
"Debt-to-income policy: the requested loan must not exceed 65% of annual income. "
"Borderline cases (60-65%) require a supervisor sign-off."
),
label="Policy: income ratio",
level=SegmentLevel.DOCUMENT,
),
Segment(
id="doc_2",
content="Payment retry guide: ask the user to clear browser cache and retry. Most payment failures resolve within 24 hours.",
label="Retrieved doc: payment retry guide",
id="employment_policy",
content=(
"Employment policy: applicants must have at least 12 months of continuous employment. "
"Less than 24 months at the current employer is considered borderline."
),
label="Policy: employment",
level=SegmentLevel.DOCUMENT,
),
Segment(
id="prior_history",
content=(
"Credit history: no prior defaults. "
"One missed payment 18 months ago, now resolved. "
"No open collections or bankruptcies."
),
label="Prior credit history",
level=SegmentLevel.DOCUMENT,
),
]

# --- Full messages for the decision call ---
# --- Messages ---

MESSAGES = [
{"role": "system", "content": SEGMENTS[0].content},
{"role": "user", "content": SEGMENTS[1].content},
{
"role": "user",
"content": (f"Context documents:\n\n[Doc 1] {SEGMENTS[2].content}\n\n[Doc 2] {SEGMENTS[3].content}"),
"content": (
f"Please review this loan application and decide on the next action.\n\n"
f"Application:\n{SEGMENTS[1].content}\n\n"
f"Relevant policies and history:\n"
f"[Credit policy] {SEGMENTS[2].content}\n"
f"[Income policy] {SEGMENTS[3].content}\n"
f"[Employment policy] {SEGMENTS[4].content}\n"
f"[Credit history] {SEGMENTS[5].content}"
),
},
]

# --- Tool definitions ---
# --- Tools ---

TOOLS = [
{
"type": "function",
"function": {
"name": "escalate_to_human",
"description": "Escalate the ticket to a human support agent.",
"name": "approve_loan",
"description": "Approve the loan application. All criteria are clearly met.",
"parameters": {"type": "object", "properties": {}, "required": []},
},
},
{
"type": "function",
"function": {
"name": "send_retry_instructions",
"description": "Send automated payment retry instructions to the user.",
"name": "flag_for_manual_review",
"description": "Flag the application for manual review by a senior underwriter. Use when any criterion is borderline.",
"parameters": {"type": "object", "properties": {}, "required": []},
},
},
{
"type": "function",
"function": {
"name": "auto_resolve",
"description": "Mark the ticket as resolved automatically.",
"name": "request_additional_documents",
"description": "Ask the applicant for missing or incomplete documentation before proceeding.",
"parameters": {"type": "object", "properties": {}, "required": []},
},
},
Expand All @@ -85,34 +136,39 @@
async def main() -> None:
"""Run the saliency demo."""
client = AsyncOpenAI(base_url=PROXY_URL, api_key=API_KEY)
engine = SaliencyEngine(client=client, model=MODEL, top_k=2)
# n_samples=5: run each masked context 5 times at temperature=0.7 and
# measure P(original_decision | masked). Produces a genuine distribution
# rather than near-binary logprob scores.
engine = SaliencyEngine(client=client, model=MODEL, top_k=2, n_samples=5)

print(f"Model: {MODEL}")
print(f"Running hierarchical occlusion on {len(SEGMENTS)} segments...\n")
print(f"Segments: {len(SEGMENTS)}")
print("Running hierarchical occlusion (n_samples=5, temperature=0.7)...\n")

result = await engine.explain_async(
messages=MESSAGES, # type: ignore[arg-type]
segments=SEGMENTS,
tools=TOOLS, # type: ignore[arg-type]
)

print(f"Original decision: {result.original_decision}\n")
print("=== Pass 1: Segment-level importance ===")
print(f"Decision: {result.original_decision}\n")

print("=== Pass 1: segment-level importance ===")
doc_scores = [s for s in result.top if "__s" not in s.segment_id]
for score in doc_scores:
bar = "█" * int(score.importance * 20)
changed = " flipped" if score.decision_changed else ""
print(f" {score.label:<40} {score.importance:.3f} {bar}{changed}")
bar = "█" * int(score.importance * 30)
changed = " [flipped]" if score.decision_changed else ""
print(f" {score.label:<38} {score.importance:.3f} {bar}{changed}")

print("\n=== Pass 2: Sentence-level (top segments drilled in) ===")
print("\n=== Pass 2: sentence-level (top segments) ===")
sent_scores = [s for s in result.top if "__s" in s.segment_id]
if sent_scores:
for score in sent_scores:
bar = "█" * int(score.importance * 20)
changed = " flipped" if score.decision_changed else ""
print(f" {score.label[:60]:<62} {score.importance:.3f} {bar}{changed}")
bar = "█" * int(score.importance * 30)
changed = " [flipped]" if score.decision_changed else ""
print(f" {score.label[:55]:<57} {score.importance:.3f} {bar}{changed}")
else:
print(" (no multi-sentence segments found)")
print(" (top segments were single sentences)")

print(f"\n{result.summary()}")

Expand Down
112 changes: 88 additions & 24 deletions src/motive/saliency.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _extract_tool_name(response: Any) -> str:
def _logprob_for_tool(response: Any, tool_name: str) -> float | None:
"""Sum logprobs over the tokens that spell out `tool_name` in the response.

Tool names are multi-token (e.g. "escalate_to_human" ["escal","ate","_to","_human"]),
Tool names are multi-token (e.g. "escalate_to_human" -> ["escal","ate","_to","_human"]),
so we reconstruct the full generated string, locate the tool name span, and sum
the logprobs of every token that overlaps with it.
"""
Expand Down Expand Up @@ -107,20 +107,35 @@ class SaliencyEngine:
Pass 1 masks at segment (document/message) level.
Pass 2 drills into the top_k segments at sentence level.

Importance is the drop in log-probability of the original tool choice
when a segment is masked. Falls back to binary (flipped = 1.0,
unchanged = 0.0) when logprobs are unavailable.
Scoring modes
-------------
n_samples=1 (default):
Single deterministic run (temperature=0). Importance = drop in
log-probability of the original tool when the segment is masked.
Fast but can produce near-binary scores when the model is very confident.

n_samples>1:
Monte Carlo sampling (temperature=sample_temperature). Runs each
masked context n_samples times and measures how often the original
decision survives. Importance = P(original | full) - P(original | masked).
Slower but produces a genuine probability distribution (e.g. 0.0, 0.2,
0.4, 0.6, 0.8, 1.0 with n_samples=5). Use this when logprob drops are
too small to distinguish importance levels.
"""

def __init__(
self,
client: AsyncOpenAI,
model: str,
top_k: int = _TOP_K_DEFAULT,
n_samples: int = 1,
sample_temperature: float = 0.7,
) -> None:
self.client = client
self.model = model
self.top_k = top_k
self.n_samples = n_samples
self.sample_temperature = sample_temperature

def explain(
self,
Expand All @@ -139,16 +154,24 @@ async def explain_async(
tools: list[ChatCompletionToolParam],
tool_choice: ToolChoice = "auto",
) -> SaliencyResult:
"""Async entry point."""
original_resp = await self._call(messages, tools, tool_choice, logprobs=True)
original_tool = _extract_tool_name(original_resp)
original_lp = _logprob_for_tool(original_resp, original_tool)
"""Run explanation asynchronously."""
if self.n_samples > 1:
original_prob = await self._sample_probability(messages, tools, tool_choice)
original_tool = await self._get_tool(messages, tools, tool_choice)
original_lp = None
else:
original_resp = await self._call(messages, tools, tool_choice, logprobs=True)
original_tool = _extract_tool_name(original_resp)
original_lp = _logprob_for_tool(original_resp, original_tool)
original_prob = 1.0

# Pass 1: segment level
pass1 = await self._occlusion_pass(messages, segments, tools, tool_choice, original_tool, original_lp)
pass1 = await self._occlusion_pass(
messages, segments, tools, tool_choice, original_tool, original_lp, original_prob
)
pass1 = _normalise(pass1)

# Pass 2: sentence level on top_k segments that actually had non-zero importance
# Pass 2: sentence level on top_k segments with non-zero importance
top_ids = {
s.segment_id
for s in sorted(pass1, key=lambda s: s.importance, reverse=True)[: self.top_k]
Expand All @@ -171,7 +194,7 @@ async def explain_async(
for i, s in enumerate(sentences)
]
sub_scores = await self._occlusion_pass(
messages, sub_segments, tools, tool_choice, original_tool, original_lp
messages, sub_segments, tools, tool_choice, original_tool, original_lp, original_prob
)
pass2.extend(_normalise(sub_scores))

Expand All @@ -189,9 +212,10 @@ async def _occlusion_pass(
tool_choice: ToolChoice,
original_tool: str,
original_lp: float | None,
original_prob: float,
) -> list[SaliencyScore]:
tasks = [
self._score_segment(messages, segments, i, tools, tool_choice, original_tool, original_lp)
self._score_segment(messages, segments, i, tools, tool_choice, original_tool, original_lp, original_prob)
for i in range(len(segments))
]
return list(await asyncio.gather(*tasks))
Expand All @@ -205,20 +229,26 @@ async def _score_segment(
tool_choice: ToolChoice,
original_tool: str,
original_lp: float | None,
original_prob: float,
) -> SaliencyScore:
masked_msgs = _mask_messages(messages, segments, mask_idx)
resp = await self._call(masked_msgs, tools, tool_choice, logprobs=True)
masked_tool = _extract_tool_name(resp)

if original_lp is not None:
masked_lp = _logprob_for_tool(resp, original_tool)
importance = (
max(0.0, original_lp - masked_lp)
if masked_lp is not None
else (1.0 if masked_tool != original_tool else 0.0)
)

if self.n_samples > 1:
masked_prob = await self._sample_probability(masked_msgs, tools, tool_choice, target_tool=original_tool)
importance = max(0.0, original_prob - masked_prob)
masked_tool = original_tool if masked_prob >= 0.5 else "other"
else:
importance = 1.0 if masked_tool != original_tool else 0.0
resp = await self._call(masked_msgs, tools, tool_choice, logprobs=True)
masked_tool = _extract_tool_name(resp)
if original_lp is not None:
masked_lp = _logprob_for_tool(resp, original_tool)
importance = (
max(0.0, original_lp - masked_lp)
if masked_lp is not None
else (1.0 if masked_tool != original_tool else 0.0)
)
else:
importance = 1.0 if masked_tool != original_tool else 0.0

seg = segments[mask_idx]
return SaliencyScore(
Expand All @@ -230,12 +260,46 @@ async def _score_segment(
masked_decision=masked_tool,
)

async def _sample_probability(
self,
messages: list[ChatCompletionMessageParam],
tools: list[ChatCompletionToolParam],
tool_choice: ToolChoice,
target_tool: str | None = None,
) -> float:
"""Estimate P(target_tool | messages) via n_samples draws at sample_temperature."""
resps = await asyncio.gather(
*[
self._call(messages, tools, tool_choice, temperature=self.sample_temperature)
for _ in range(self.n_samples)
]
)
if target_tool is None:
# First call: use the majority decision as target
target_tool = max(
{_extract_tool_name(r) for r in resps},
key=lambda t: sum(1 for r in resps if _extract_tool_name(r) == t),
)
matches = sum(1 for r in resps if _extract_tool_name(r) == target_tool)
return matches / self.n_samples

async def _get_tool(
self,
messages: list[ChatCompletionMessageParam],
tools: list[ChatCompletionToolParam],
tool_choice: ToolChoice,
) -> str:
"""Get the deterministic tool choice (temperature=0) for the original context."""
resp = await self._call(messages, tools, tool_choice)
return _extract_tool_name(resp)

async def _call(
self,
messages: list[ChatCompletionMessageParam],
tools: list[ChatCompletionToolParam],
tool_choice: ToolChoice,
logprobs: bool = False,
temperature: float = 0.0,
) -> Any:
return await self.client.chat.completions.create(
model=self.model,
Expand All @@ -244,5 +308,5 @@ async def _call(
tool_choice=tool_choice,
logprobs=logprobs,
top_logprobs=5 if logprobs else None,
temperature=0,
temperature=temperature,
)
Loading