Skip to content

Commit 5affe13

Browse files
dcramercodex
andcommitted
Stabilize DM threading and enforce mutation confirmation proof
Co-Authored-By: GPT-5 Codex <codex@openai.com>
1 parent 8db630e commit 5affe13

11 files changed

Lines changed: 666 additions & 11 deletions

File tree

packages/ash-sandbox-cli/src/ash_sandbox_cli/commands/capability.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,20 @@ def invoke_capability(
122122
str | None,
123123
typer.Option("--account", help="Optional linked account alias"),
124124
] = None,
125+
mutation_plan_id: Annotated[
126+
str | None,
127+
typer.Option(
128+
"--plan-id",
129+
help="Optional mutation plan id for host confirmation proof",
130+
),
131+
] = None,
132+
target_fingerprint: Annotated[
133+
str | None,
134+
typer.Option(
135+
"--target-fingerprint",
136+
help="Optional target fingerprint for host confirmation proof",
137+
),
138+
] = None,
125139
) -> None:
126140
"""Invoke one capability operation."""
127141
try:
@@ -139,6 +153,10 @@ def invoke_capability(
139153
params["idempotency_key"] = idempotency_key
140154
if account:
141155
params["account_ref"] = account
156+
if mutation_plan_id:
157+
params["mutation_plan_id"] = mutation_plan_id
158+
if target_fingerprint:
159+
params["target_fingerprint"] = target_fingerprint
142160

143161
result = _call("capability.invoke", params)
144162
request_id = result.get("request_id", "?")

specs/sessions.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,13 @@ Ash uses **per-user sessions scoped to thread** for group chats:
1111
- Standalone `@mention` messages create a new thread (thread_id = message external_id)
1212
- Replies follow the parent thread via `ThreadIndex`
1313
- Session key for groups: `telegram_{chat_id}_{user_id}_{thread_id}` (users in the same thread do not share session state)
14-
- DMs use a single session: `telegram_{chat_id}_{user_id}` (no thread_id)
14+
15+
For DMs, Ash uses **hybrid active-thread routing**:
16+
17+
- Replies follow the parent thread via `ThreadIndex`.
18+
- Non-reply messages continue on the active DM thread when it is fresh.
19+
- A new thread is created when no active thread is available (or after explicit new-topic intent/timeout rollover).
20+
- Session key for DM turns remains thread-scoped when a thread_id exists: `telegram_{chat_id}_{user_id}_{thread_id}`.
1521

1622
## File Structure
1723

src/ash/capabilities/manager.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
CapabilityDefinition,
3030
CapabilityInvokeResult,
3131
)
32+
from ash.chats import ChatStateManager
3233

3334
_NAMESPACED_CAPABILITY_ID = re.compile(r"^[a-z0-9][a-z0-9_-]*\.[a-z0-9][a-z0-9_-]*$")
3435
_NAMESPACE = re.compile(r"^[a-z0-9][a-z0-9_-]*$")
@@ -687,6 +688,8 @@ async def invoke(
687688
source_display_name: str | None = None,
688689
idempotency_key: str | None = None,
689690
account_ref: str | None = None,
691+
mutation_plan_id: str | None = None,
692+
target_fingerprint: str | None = None,
690693
) -> CapabilityInvokeResult:
691694
"""Invoke one capability operation under caller scope."""
692695
normalized_user_id = _required_text(
@@ -709,6 +712,8 @@ async def invoke(
709712
normalized_source_display_name = _optional_text(source_display_name)
710713
normalized_idempotency_key = _optional_text(idempotency_key)
711714
normalized_account_ref = _optional_text(account_ref)
715+
normalized_mutation_plan_id = _optional_text(mutation_plan_id)
716+
normalized_target_fingerprint = _optional_text(target_fingerprint)
712717

713718
account_ref: str | None = None
714719
provider_impl: CapabilityProvider | None = None
@@ -774,6 +779,22 @@ async def invoke(
774779
source_username=normalized_source_username,
775780
source_display_name=normalized_source_display_name,
776781
)
782+
confirmed_plan_id: str | None = None
783+
if _requires_mutation_confirmation(
784+
capability_id=normalized_capability_id,
785+
operation=normalized_operation,
786+
provider=normalized_provider,
787+
mutating=bool(op.mutating),
788+
):
789+
confirmed_plan_id = _assert_mutation_confirmation_proof(
790+
chat_id=normalized_chat_id,
791+
thread_id=normalized_thread_id,
792+
capability_id=normalized_capability_id,
793+
operation=normalized_operation,
794+
mutation_plan_id=normalized_mutation_plan_id,
795+
target_fingerprint=normalized_target_fingerprint,
796+
)
797+
777798
raw_output = await self._provider_invoke(
778799
provider_impl,
779800
capability_id=normalized_capability_id,
@@ -791,6 +812,11 @@ async def invoke(
791812
)
792813

793814
request_id = f"cap_{secrets.token_hex(8)}"
815+
if confirmed_plan_id and normalized_chat_id:
816+
_mark_mutation_plan_executed(
817+
chat_id=normalized_chat_id,
818+
plan_id=confirmed_plan_id,
819+
)
794820

795821
return CapabilityInvokeResult(
796822
request_id=request_id,
@@ -1135,6 +1161,72 @@ def _find_sensitive_key_path(value: Any, path: str = "output") -> str | None:
11351161
return None
11361162

11371163

1164+
def _requires_mutation_confirmation(
1165+
*,
1166+
capability_id: str,
1167+
operation: str,
1168+
provider: str | None,
1169+
mutating: bool,
1170+
) -> bool:
1171+
if not mutating:
1172+
return False
1173+
if provider != "telegram":
1174+
return False
1175+
return capability_id == "gog.email" and operation in {
1176+
"archive_messages",
1177+
"update_labels",
1178+
}
1179+
1180+
1181+
def _assert_mutation_confirmation_proof(
    *,
    chat_id: str | None,
    thread_id: str | None,
    capability_id: str,
    operation: str,
    mutation_plan_id: str | None,
    target_fingerprint: str | None,
) -> str:
    """Require a confirmed mutation plan in chat state before mutating.

    Loads the telegram chat state for ``chat_id``, prunes expired
    confirmations, and looks up a confirmed entry matching the capability,
    operation and (when supplied) target fingerprint and thread.

    Args:
        chat_id: Chat whose state holds the confirmation records.
        thread_id: Thread the invocation belongs to, if any.
        capability_id: Capability being invoked.
        operation: Operation being invoked.
        mutation_plan_id: Plan id asserted by the caller, if any.
        target_fingerprint: Target fingerprint asserted by the caller, if any.

    Returns:
        The ``plan_id`` of the matching confirmed record.

    Raises:
        CapabilityError: ``capability_mutation_not_confirmed`` when there is
            no chat id or no matching confirmed record;
            ``capability_mutation_plan_mismatch`` when a supplied plan id
            differs from the confirmed record's plan id.
    """
    scoped_chat_id = _optional_text(chat_id)
    if scoped_chat_id is None:
        raise CapabilityError(
            "capability_mutation_not_confirmed",
            "mutating operation requires chat-scoped confirmation proof",
        )

    chat_state_manager = ChatStateManager(provider="telegram", chat_id=scoped_chat_id)
    chat_state = chat_state_manager.load()
    chat_state.prune_expired_mutation_confirmations()
    proof = chat_state.find_confirmed_mutation(
        capability_id=capability_id,
        operation=operation,
        target_fingerprint=target_fingerprint,
        thread_id=thread_id,
    )
    if proof is None:
        raise CapabilityError(
            "capability_mutation_not_confirmed",
            (
                "mutation requires prior confirmation in chat history; "
                "show targets and get explicit user confirm first"
            ),
        )

    # A caller-supplied plan id is optional, but when present it must agree.
    plan_id_conflicts = bool(mutation_plan_id) and mutation_plan_id != proof.plan_id
    if plan_id_conflicts:
        raise CapabilityError(
            "capability_mutation_plan_mismatch",
            "provided mutation_plan_id does not match confirmed chat plan",
        )

    # Save is only reached on success; presumably this persists the pruned
    # state loaded above -- confirm against ChatStateManager.
    chat_state_manager.save()
    return proof.plan_id
1221+
1222+
1223+
def _mark_mutation_plan_executed(*, chat_id: str, plan_id: str) -> None:
    """Record in telegram chat state that ``plan_id`` has been executed.

    The state is saved only when ``mark_mutation_executed`` reports that it
    found and updated a record for the plan.

    Args:
        chat_id: Chat whose state holds the confirmation record.
        plan_id: Identifier of the mutation plan that was executed.
    """
    chat_state_manager = ChatStateManager(provider="telegram", chat_id=chat_id)
    chat_state = chat_state_manager.load()
    was_updated = chat_state.mark_mutation_executed(plan_id=plan_id)
    if not was_updated:
        return
    chat_state_manager.save()
1228+
1229+
11381230
async def create_capability_manager(
11391231
*,
11401232
providers: list[CapabilityProvider] | None = None,

src/ash/chats/models.py

Lines changed: 162 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Chat state models."""
22

3-
from datetime import UTC, datetime
3+
from datetime import UTC, datetime, timedelta
44

55
from pydantic import BaseModel, Field
66

@@ -28,12 +28,32 @@ class ChatInfo(BaseModel):
2828
title: str | None = None
2929

3030

31+
class MutationConfirmation(BaseModel):
    """Chat-scoped proof that a mutating operation was shown and confirmed."""

    # Identifier of the mutation plan; used to look the record up when it is
    # marked as executed.
    plan_id: str
    # Capability and operation the plan applies to; both must match exactly
    # when searching for a confirmed record.
    capability_id: str
    operation: str
    # Lifecycle marker. New records start as "presented"; ChatState methods
    # move them to "confirmed" and then "executed".
    status: str = "presented"  # presented | confirmed | executed
    # Defaults to the current UTC time when the record is constructed.
    presented_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    # Required. Records whose expires_at is not later than the pruning time
    # are dropped by ChatState.prune_expired_mutation_confirmations.
    expires_at: datetime
    # Set when the record transitions to "confirmed" / "executed".
    confirmed_at: datetime | None = None
    executed_at: datetime | None = None
    # Optional match filters: compared only when both the record and the
    # lookup supply a value.
    target_fingerprint: str | None = None
    thread_id: str | None = None
    # Free-form text; presumably a human-readable description of the plan
    # shown to the user -- TODO confirm against the code that creates records.
    summary: str | None = None
45+
46+
3147
class ChatState(BaseModel):
3248
"""State for a chat, stored in state.json."""
3349

3450
chat: ChatInfo
3551
participants: list[Participant] = Field(default_factory=list)
3652
thread_index: dict[str, str] = Field(default_factory=dict)
53+
active_thread_id: str | None = None
54+
active_thread_updated_at: datetime | None = None
55+
active_thread_reason: str | None = None
56+
mutation_confirmations: list[MutationConfirmation] = Field(default_factory=list)
3757
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
3858
graph_chat_id: str | None = None # Reference to graph ChatEntry.id
3959

@@ -73,3 +93,144 @@ def update_participant(
7393

7494
self.updated_at = now
7595
return participant
96+
97+
def set_active_thread(
98+
self,
99+
thread_id: str,
100+
*,
101+
reason: str,
102+
now: datetime | None = None,
103+
) -> None:
104+
"""Record the active thread for chat-scoped DM routing."""
105+
ts = now or datetime.now(UTC)
106+
self.active_thread_id = str(thread_id)
107+
self.active_thread_updated_at = ts
108+
self.active_thread_reason = reason
109+
self.updated_at = ts
110+
111+
def get_active_thread(
112+
self,
113+
*,
114+
max_age_minutes: int,
115+
now: datetime | None = None,
116+
) -> str | None:
117+
"""Return active thread_id if it is still within the freshness window."""
118+
if not self.active_thread_id or not self.active_thread_updated_at:
119+
return None
120+
ts = now or datetime.now(UTC)
121+
max_age = max(1, int(max_age_minutes))
122+
if ts - self.active_thread_updated_at > timedelta(minutes=max_age):
123+
return None
124+
return self.active_thread_id
125+
126+
def add_mutation_confirmation(
    self,
    *,
    plan_id: str,
    capability_id: str,
    operation: str,
    target_fingerprint: str | None = None,
    thread_id: str | None = None,
    summary: str | None = None,
    ttl_hours: int = 24,
    now: datetime | None = None,
) -> MutationConfirmation:
    """Store a mutation confirmation prompt shown to the user.

    Expired records are pruned first, then a new record in the default
    ``"presented"`` status is appended.

    Args:
        plan_id: Identifier of the mutation plan being presented.
        capability_id: Capability the plan would invoke.
        operation: Operation the plan would invoke.
        target_fingerprint: Optional fingerprint of the plan's targets.
        thread_id: Optional thread the plan was presented in.
        summary: Optional free-form description of the plan.
        ttl_hours: Lifetime of the record in hours; values below 1 are
            treated as 1.
        now: Reference time; defaults to the current UTC time.

    Returns:
        The newly appended ``MutationConfirmation``.
    """
    ts = now or datetime.now(UTC)
    self.prune_expired_mutation_confirmations(now=ts)
    confirmation = MutationConfirmation(
        plan_id=plan_id,
        capability_id=capability_id,
        operation=operation,
        # Use the same reference time as expires_at; otherwise an injected
        # `now` yields a record presented at wall-clock time but expiring
        # relative to `now`.
        presented_at=ts,
        expires_at=ts + timedelta(hours=max(1, int(ttl_hours))),
        target_fingerprint=target_fingerprint,
        thread_id=thread_id,
        summary=summary,
    )
    self.mutation_confirmations.append(confirmation)
    self.updated_at = ts
    return confirmation
153+
154+
def confirm_latest_mutation(
155+
self,
156+
*,
157+
now: datetime | None = None,
158+
thread_id: str | None = None,
159+
) -> MutationConfirmation | None:
160+
"""Confirm the latest non-expired presented mutation plan."""
161+
ts = now or datetime.now(UTC)
162+
self.prune_expired_mutation_confirmations(now=ts)
163+
for confirmation in reversed(self.mutation_confirmations):
164+
if confirmation.status != "presented":
165+
continue
166+
if (
167+
thread_id
168+
and confirmation.thread_id
169+
and confirmation.thread_id != thread_id
170+
):
171+
continue
172+
confirmation.status = "confirmed"
173+
confirmation.confirmed_at = ts
174+
self.updated_at = ts
175+
return confirmation
176+
return None
177+
178+
def find_confirmed_mutation(
179+
self,
180+
*,
181+
capability_id: str,
182+
operation: str,
183+
target_fingerprint: str | None = None,
184+
thread_id: str | None = None,
185+
now: datetime | None = None,
186+
) -> MutationConfirmation | None:
187+
"""Find a non-expired confirmed mutation authorization."""
188+
ts = now or datetime.now(UTC)
189+
self.prune_expired_mutation_confirmations(now=ts)
190+
for confirmation in reversed(self.mutation_confirmations):
191+
if confirmation.status != "confirmed":
192+
continue
193+
if confirmation.capability_id != capability_id:
194+
continue
195+
if confirmation.operation != operation:
196+
continue
197+
if target_fingerprint and confirmation.target_fingerprint:
198+
if confirmation.target_fingerprint != target_fingerprint:
199+
continue
200+
if (
201+
thread_id
202+
and confirmation.thread_id
203+
and confirmation.thread_id != thread_id
204+
):
205+
continue
206+
return confirmation
207+
return None
208+
209+
def mark_mutation_executed(
210+
self,
211+
*,
212+
plan_id: str,
213+
now: datetime | None = None,
214+
) -> bool:
215+
"""Mark a confirmed mutation plan as executed."""
216+
ts = now or datetime.now(UTC)
217+
for confirmation in self.mutation_confirmations:
218+
if confirmation.plan_id != plan_id:
219+
continue
220+
confirmation.status = "executed"
221+
confirmation.executed_at = ts
222+
self.updated_at = ts
223+
return True
224+
return False
225+
226+
def prune_expired_mutation_confirmations(
227+
self,
228+
*,
229+
now: datetime | None = None,
230+
) -> None:
231+
"""Remove expired mutation confirmation entries."""
232+
ts = now or datetime.now(UTC)
233+
kept = [item for item in self.mutation_confirmations if item.expires_at > ts]
234+
if len(kept) != len(self.mutation_confirmations):
235+
self.mutation_confirmations = kept
236+
self.updated_at = ts

src/ash/providers/telegram/handlers/message_handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,8 @@ async def _process_single_message_inner(
492492
if isinstance(candidate, IncomingMessage):
493493
message = candidate
494494

495+
self._session_handler.maybe_record_mutation_confirmation_from_user(message)
496+
495497
if await self._try_handle_capability_oauth_callback(message):
496498
return
497499

0 commit comments

Comments
 (0)