Skip to content

Commit b26af99

Browse files
dcramercodex
andcommitted
Fix scheduled reply routing and DM chat-type propagation
- treat replies to bot messages as active even when thread index misses - normalize thread index keys to strings for consistent lookup - propagate schedule chat_type through CLI/RPC/store into scheduled session context - add regression tests for schedule chat_type and reply-skip policy Co-Authored-By: GPT-5 Codex <noreply@openai.com>
1 parent 83b6d2a commit b26af99

11 files changed

Lines changed: 125 additions & 13 deletions

File tree

packages/ash-sandbox-cli/src/ash_sandbox_cli/commands/schedule.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def _get_context() -> dict[str, str]:
2222
return {
2323
"user_id": context.get("user_id") or "",
2424
"chat_id": context.get("chat_id") or "",
25+
"chat_type": context.get("chat_type") or "",
2526
"chat_title": context.get("chat_title") or "",
2627
"provider": context.get("provider") or "",
2728
"username": context.get("username") or "",
@@ -257,6 +258,8 @@ def create(
257258
params["cron"] = cron
258259
if ctx["chat_title"]:
259260
params["chat_title"] = ctx["chat_title"]
261+
if ctx["chat_type"]:
262+
params["chat_type"] = ctx["chat_type"]
260263
if ctx["user_id"]:
261264
params["user_id"] = ctx["user_id"]
262265
if ctx["username"]:

src/ash/chats/thread_index.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,26 +51,28 @@ def resolve_thread_id(
5151
returns the same thread_id. Otherwise returns external_id as a new thread.
5252
"""
5353
index = self._ensure_loaded()
54+
external_key = str(external_id)
55+
reply_key = str(reply_to_external_id) if reply_to_external_id else None
5456

55-
if reply_to_external_id:
56-
parent_thread = index.get(reply_to_external_id)
57+
if reply_key:
58+
parent_thread = index.get(reply_key)
5759
if parent_thread:
5860
logger.debug(
5961
"Message %s joins thread %s (via reply to %s)",
60-
external_id,
62+
external_key,
6163
parent_thread,
62-
reply_to_external_id,
64+
reply_key,
6365
)
6466
return parent_thread
6567
logger.debug(
6668
"Message %s replied to unknown message %s, starting new thread",
67-
external_id,
68-
reply_to_external_id,
69+
external_key,
70+
reply_key,
6971
)
7072

7173
# Start new thread using this message's ID as the thread_id
72-
logger.debug("Message %s starts new thread", external_id)
73-
return external_id
74+
logger.debug("Message %s starts new thread", external_key)
75+
return external_key
7476

7577
def register_message(self, external_id: str, thread_id: str) -> None:
7678
"""Register a message in a thread.
@@ -81,11 +83,13 @@ def register_message(self, external_id: str, thread_id: str) -> None:
8183
"""
8284
with self._lock:
8385
index = self._ensure_loaded()
84-
if external_id not in index:
85-
index[external_id] = thread_id
86+
external_key = str(external_id)
87+
thread_key = str(thread_id)
88+
if external_key not in index:
89+
index[external_key] = thread_key
8690
self._manager.save()
8791
logger.debug(
88-
"Registered message %s in thread %s", external_id, thread_id
92+
"Registered message %s in thread %s", external_key, thread_key
8993
)
9094

9195
def get_thread_id(self, external_id: str) -> str | None:
@@ -98,4 +102,4 @@ def get_thread_id(self, external_id: str) -> str | None:
98102
The thread_id if the message is registered, None otherwise.
99103
"""
100104
index = self._ensure_loaded()
101-
return index.get(external_id)
105+
return index.get(str(external_id))

src/ash/providers/telegram/handlers/session_handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,8 @@ async def should_skip_reply(self, message: IncomingMessage) -> bool:
372372
return False
373373
if message.metadata.get("was_mentioned", False):
374374
return False
375+
if message.metadata.get("is_reply_to_bot", False):
376+
return False
375377

376378
# Check thread index first
377379
thread_index = self.get_thread_index(message.chat_id)

src/ash/providers/telegram/provider.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,12 +471,14 @@ def _to_incoming_message(
471471
images: list[ImageAttachment] | None = None,
472472
*,
473473
was_mentioned: bool = False,
474+
is_reply_to_bot: bool = False,
474475
) -> IncomingMessage:
475476
"""Convert a Telegram message to an IncomingMessage."""
476477
metadata = {
477478
"chat_type": message.chat.type,
478479
"chat_title": message.chat.title,
479480
"was_mentioned": was_mentioned,
481+
"is_reply_to_bot": is_reply_to_bot,
480482
}
481483
# Include thread_id for forum topics (supergroups with topics enabled)
482484
if message.message_thread_id is not None:
@@ -621,6 +623,7 @@ async def handle_photo(message: TelegramMessage) -> None:
621623
# Strip bot mention from caption if in group
622624
is_group = message.chat.type in ("group", "supergroup")
623625
was_mentioned = is_group and self._is_mentioned(message)
626+
is_reply_to_bot = is_group and self._is_reply(message)
624627
caption = message.caption or ""
625628
if is_group and caption:
626629
caption = self._strip_mention(caption)
@@ -632,6 +635,7 @@ async def handle_photo(message: TelegramMessage) -> None:
632635
caption,
633636
images=[image],
634637
was_mentioned=was_mentioned,
638+
is_reply_to_bot=is_reply_to_bot,
635639
)
636640

637641
if self._handler:
@@ -655,10 +659,16 @@ async def handle_message(message: TelegramMessage) -> None:
655659
# Strip bot mention from text if in group
656660
is_group = message.chat.type in ("group", "supergroup")
657661
was_mentioned = is_group and self._is_mentioned(message)
662+
is_reply_to_bot = is_group and self._is_reply(message)
658663
text = self._strip_mention(message.text) if is_group else message.text
659664

660665
incoming = self._to_incoming_message(
661-
message, user_id, username, text, was_mentioned=was_mentioned
666+
message,
667+
user_id,
668+
username,
669+
text,
670+
was_mentioned=was_mentioned,
671+
is_reply_to_bot=is_reply_to_bot,
662672
)
663673
# Add processing mode to metadata
664674
incoming.metadata["processing_mode"] = processing_mode

src/ash/rpc/methods/schedule.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ async def schedule_create(params: dict[str, Any]) -> dict[str, Any]:
3535
trigger_at: ISO datetime for one-shot (mutually exclusive with cron)
3636
cron: Cron expression for periodic (mutually exclusive with trigger_at)
3737
chat_id: Target chat ID (required)
38+
chat_type: Chat type for policy checks at execution time (optional)
3839
provider: Provider name (required)
3940
user_id: User ID
4041
username: Username
@@ -67,6 +68,7 @@ async def schedule_create(params: dict[str, Any]) -> dict[str, Any]:
6768
trigger_at=trigger_at,
6869
cron=cron,
6970
chat_id=chat_id,
71+
chat_type=params.get("chat_type"),
7072
chat_title=params.get("chat_title"),
7173
provider=provider,
7274
user_id=params.get("user_id"),

src/ash/scheduling/handler.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,21 @@ def format_delay(seconds: float) -> str:
9090
return f"~{days:.1f} days"
9191

9292

93+
def _resolve_chat_type(entry: ScheduleEntry) -> str | None:
94+
"""Resolve effective chat_type for scheduled executions.
95+
96+
New entries should carry chat_type directly. For older Telegram entries
97+
without this field, infer private vs group from Telegram chat ID shape.
98+
"""
99+
if entry.chat_type:
100+
return entry.chat_type
101+
102+
if entry.provider == "telegram" and entry.chat_id:
103+
return "group" if entry.chat_id.startswith("-") else "private"
104+
105+
return None
106+
107+
93108
class MessageSender(Protocol):
94109
"""Protocol for sending messages to a chat. Returns the sent message ID."""
95110

@@ -217,6 +232,7 @@ async def handle(self, entry: ScheduleEntry) -> None:
217232
# Populate context so system prompt builder includes full context
218233
session.context.username = entry.username or ""
219234
session.context.is_scheduled_task = True
235+
session.context.chat_type = _resolve_chat_type(entry)
220236
if entry.chat_title:
221237
session.context.chat_title = entry.chat_title
222238

src/ash/scheduling/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class ScheduleEntry:
3131
timezone: str | None = None
3232
# Context for routing response back
3333
chat_id: str | None = None
34+
chat_type: str | None = None # "private", "group", "supergroup", ...
3435
chat_title: str | None = None # Friendly name for the chat
3536
user_id: str | None = None
3637
username: str | None = None # For @mentions in response
@@ -234,6 +235,8 @@ def to_dict(self) -> dict[str, Any]:
234235
# Context fields
235236
if self.chat_id:
236237
data["chat_id"] = self.chat_id
238+
if self.chat_type:
239+
data["chat_type"] = self.chat_type
237240
if self.chat_title:
238241
data["chat_title"] = self.chat_title
239242
if self.user_id:
@@ -290,6 +293,7 @@ def parse_datetime(key: str) -> datetime | None:
290293
"last_run",
291294
"timezone",
292295
"chat_id",
296+
"chat_type",
293297
"chat_title",
294298
"user_id",
295299
"username",
@@ -306,6 +310,7 @@ def parse_datetime(key: str) -> datetime | None:
306310
last_run=last_run,
307311
timezone=data.get("timezone"),
308312
chat_id=data.get("chat_id"),
313+
chat_type=data.get("chat_type"),
309314
chat_title=data.get("chat_title"),
310315
user_id=data.get("user_id"),
311316
username=data.get("username"),

tests/test_sandbox_cli.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@ def _context_token(
1515
*,
1616
effective_user_id: str = "user123",
1717
chat_id: str | None = "chat456",
18+
chat_type: str | None = "private",
1819
provider: str | None = "telegram",
1920
source_username: str | None = "testuser",
2021
timezone: str | None = "UTC",
2122
) -> str:
2223
return get_default_context_token_service().issue(
2324
effective_user_id=effective_user_id,
2425
chat_id=chat_id,
26+
chat_type=chat_type,
2527
provider=provider,
2628
source_username=source_username,
2729
timezone=timezone,
@@ -82,6 +84,7 @@ def test_create_one_shot(self, cli_runner, mock_rpc):
8284
params = call_args[1]
8385
assert params["message"] == "Test reminder"
8486
assert params["chat_id"] == "chat456"
87+
assert params["chat_type"] == "private"
8588
assert params["provider"] == "telegram"
8689

8790
def test_create_periodic(self, cli_runner, mock_rpc):

tests/test_schedule.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1627,6 +1627,41 @@ async def test_handler_uses_configured_timezone(self):
16271627
# The previous 8 AM LA should show as 08:00, not UTC equivalent
16281628
assert "08:00" in wrapped_message
16291629

1630+
@pytest.mark.asyncio
1631+
async def test_handler_sets_session_chat_type_for_dm_policy(self):
1632+
"""Scheduled tasks in DMs preserve private chat_type for skill policy checks."""
1633+
from unittest.mock import AsyncMock, MagicMock
1634+
1635+
from ash.scheduling import ScheduledTaskHandler
1636+
1637+
mock_agent = MagicMock()
1638+
mock_response = MagicMock()
1639+
mock_response.text = "ok"
1640+
mock_agent.process_message = AsyncMock(return_value=mock_response)
1641+
1642+
mock_sender = AsyncMock(return_value="msg_123")
1643+
handler = ScheduledTaskHandler(
1644+
agent=mock_agent,
1645+
senders={"telegram": mock_sender},
1646+
timezone="UTC",
1647+
)
1648+
1649+
entry = ScheduleEntry(
1650+
id="dm_ctx_1",
1651+
message="Check my calendar",
1652+
trigger_at=datetime.now(UTC) - timedelta(minutes=1),
1653+
provider="telegram",
1654+
chat_id="123456789",
1655+
chat_type="private",
1656+
user_id="42",
1657+
)
1658+
1659+
await handler.handle(entry)
1660+
1661+
call_args = mock_agent.process_message.call_args
1662+
session = call_args.args[1]
1663+
assert session.context.chat_type == "private"
1664+
16301665

16311666
class TestStalenessGuard:
16321667
"""Tests for the staleness guard in ScheduleWatcher."""

tests/test_sessions.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,3 +1096,32 @@ def test_session_context_has_no_chat_history_field(self):
10961096

10971097
ctx = SessionContext()
10981098
assert not hasattr(ctx, "chat_history")
1099+
1100+
1101+
class TestGroupReplySkipPolicy:
1102+
@pytest.mark.asyncio
1103+
async def test_reply_to_bot_is_not_skipped_when_thread_unknown(self):
1104+
from ash.config.models import ConversationConfig
1105+
from ash.providers.base import IncomingMessage
1106+
from ash.providers.telegram.handlers.session_handler import SessionHandler
1107+
1108+
handler = SessionHandler(
1109+
provider_name="telegram",
1110+
config=None,
1111+
conversation_config=ConversationConfig(),
1112+
)
1113+
1114+
message = IncomingMessage(
1115+
id="201",
1116+
chat_id="-1001",
1117+
user_id="u1",
1118+
text="follow up",
1119+
reply_to_message_id="200",
1120+
metadata={
1121+
"chat_type": "group",
1122+
"was_mentioned": False,
1123+
"is_reply_to_bot": True,
1124+
},
1125+
)
1126+
1127+
assert await handler.should_skip_reply(message) is False

0 commit comments

Comments
 (0)