Skip to content

Commit 345e2cd

Browse files
dcramercodex
andcommitted
Reuse pending capability auth flows and complete from pasted callback
Avoid regenerating OAuth state when a matching pending flow exists, persist auth begin response fields on flow records, and update Google skill instructions to immediately complete auth when callback URL/code is already present. Also document pending-flow reuse in capabilities spec and add coverage for auth_begin reuse semantics. Co-Authored-By: GPT-5 Codex <codex@openai.com>
1 parent e9563cf commit 345e2cd

5 files changed

Lines changed: 92 additions & 21 deletions

File tree

specs/capabilities.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,10 @@ Response:
324324
Starts auth for a capability/account and returns an auth flow handle.
325325
For device code flow extensions (`flow_type`, `user_code`, `auth_poll`), see `specs/capability-auth.md`.
326326

327+
If an unexpired pending auth flow already exists for the same caller scope
328+
(`effective_user_id`, `capability`, `account_hint`), the host returns that
329+
existing flow instead of creating a new one.
330+
327331
Request params:
328332

329333
```json

src/ash/capabilities/manager.py

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,13 @@ async def auth_begin(
255255
normalized_capability_id
256256
)
257257
_assert_chat_type_allowed(definition, normalized_chat_type)
258+
existing_flow = self._find_pending_auth_flow_locked(
259+
user_id=normalized_user_id,
260+
capability_id=normalized_capability_id,
261+
account_hint=normalized_account_hint,
262+
)
263+
if existing_flow is not None:
264+
return _auth_begin_response(existing_flow)
258265

259266
call_context = CapabilityCallContext(
260267
user_id=normalized_user_id,
@@ -278,29 +285,23 @@ async def auth_begin(
278285
datetime.now(UTC) + timedelta(seconds=self._auth_flow_ttl_seconds)
279286
)
280287
flow_type = begin_result.flow_type or "authorization_code"
288+
flow = CapabilityAuthFlow(
289+
flow_id=flow_id,
290+
capability_id=normalized_capability_id,
291+
user_id=normalized_user_id,
292+
account_hint=normalized_account_hint,
293+
expires_at=expires_at,
294+
auth_url=begin_result.auth_url,
295+
flow_state=dict(begin_result.flow_state),
296+
flow_type=flow_type,
297+
user_code=begin_result.user_code,
298+
poll_interval_seconds=begin_result.poll_interval_seconds,
299+
expected_callback_state=begin_result.expected_callback_state,
300+
)
281301
async with self._lock:
282-
self._auth_flows[flow_id] = CapabilityAuthFlow(
283-
flow_id=flow_id,
284-
capability_id=normalized_capability_id,
285-
user_id=normalized_user_id,
286-
account_hint=normalized_account_hint,
287-
expires_at=expires_at,
288-
flow_state=dict(begin_result.flow_state),
289-
flow_type=flow_type,
290-
expected_callback_state=begin_result.expected_callback_state,
291-
)
302+
self._auth_flows[flow_id] = flow
292303

293-
result: dict[str, Any] = {
294-
"flow_id": flow_id,
295-
"auth_url": begin_result.auth_url,
296-
"expires_at": expires_at.isoformat().replace("+00:00", "Z"),
297-
"flow_type": flow_type,
298-
}
299-
if begin_result.user_code is not None:
300-
result["user_code"] = begin_result.user_code
301-
if begin_result.poll_interval_seconds is not None:
302-
result["poll_interval_seconds"] = begin_result.poll_interval_seconds
303-
return result
304+
return _auth_begin_response(flow)
304305

305306
async def auth_complete(
306307
self,
@@ -639,6 +640,24 @@ def _get_definition_and_provider_locked(
639640
provider_impl = self._providers.get(provider_name) if provider_name else None
640641
return definition, provider_impl
641642

643+
def _find_pending_auth_flow_locked(
644+
self,
645+
*,
646+
user_id: str,
647+
capability_id: str,
648+
account_hint: str | None,
649+
) -> CapabilityAuthFlow | None:
650+
matches = [
651+
flow
652+
for flow in self._auth_flows.values()
653+
if flow.user_id == user_id
654+
and flow.capability_id == capability_id
655+
and flow.account_hint == account_hint
656+
]
657+
if not matches:
658+
return None
659+
return max(matches, key=lambda flow: flow.expires_at)
660+
642661
async def _provider_auth_begin(
643662
self,
644663
provider_impl: CapabilityProvider | None,
@@ -864,6 +883,20 @@ def _first_account_ref_locked(
864883
return sorted(refs)[0]
865884

866885

886+
def _auth_begin_response(flow: CapabilityAuthFlow) -> dict[str, Any]:
887+
result: dict[str, Any] = {
888+
"flow_id": flow.flow_id,
889+
"auth_url": flow.auth_url,
890+
"expires_at": flow.expires_at.isoformat().replace("+00:00", "Z"),
891+
"flow_type": flow.flow_type,
892+
}
893+
if flow.user_code is not None:
894+
result["user_code"] = flow.user_code
895+
if flow.poll_interval_seconds is not None:
896+
result["poll_interval_seconds"] = flow.poll_interval_seconds
897+
return result
898+
899+
867900
def _find_sensitive_key_path(value: Any, path: str = "output") -> str | None:
868901
if isinstance(value, dict):
869902
for raw_key, nested in value.items():

src/ash/capabilities/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,11 @@ class CapabilityAuthFlow:
4242
user_id: str
4343
account_hint: str | None
4444
expires_at: datetime
45+
auth_url: str
4546
flow_state: dict[str, Any] = field(default_factory=dict)
4647
flow_type: str = "authorization_code"
48+
user_code: str | None = None
49+
poll_interval_seconds: int | None = None
4750
expected_callback_state: str | None = None
4851

4952

src/ash/integrations/skills/capabilities/google/SKILL.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ Total: 2 capability(ies)
6161

6262
Run this step for each capability where `Authenticated: no`. If the user's request is setup-only (e.g. "set up my email"), stop after authentication is complete — do not invoke any operations.
6363

64+
Before prompting the user again, check whether the current task already contains a pasted Google callback URL (`http://localhost/?...code=...`) or a raw auth code. If yes:
65+
66+
1. Run `ash-sb capability auth begin -c <capability>` first.
67+
2. Immediately run `ash-sb capability auth complete --flow-id <id> --callback-url <URL>` (or `--code <CODE>`).
68+
3. Re-run `ash-sb capability list` and continue to operations if authenticated.
69+
70+
Do not ask the user for another URL/code when one is already present in the task.
71+
6472
**2a. Begin auth flow**
6573

6674
Use `--account work` or `--account personal` if the user specifies an account preference:

tests/test_capabilities.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,29 @@ async def test_auth_complete_rejects_callback_state_mismatch() -> None:
400400
assert exc_info.value.code == "capability_auth_state_mismatch"
401401

402402

403+
@pytest.mark.asyncio
404+
async def test_auth_begin_reuses_pending_flow_for_same_scope() -> None:
405+
manager = CapabilityManager(auth_flow_ttl_seconds=300)
406+
provider = _RecordingProvider(namespace="gog")
407+
await manager.register_provider(provider)
408+
409+
first = await manager.auth_begin(
410+
capability_id="gog.email",
411+
user_id="user-1",
412+
chat_type="private",
413+
account_hint="work",
414+
)
415+
second = await manager.auth_begin(
416+
capability_id="gog.email",
417+
user_id="user-1",
418+
chat_type="private",
419+
account_hint="work",
420+
)
421+
422+
assert first["flow_id"] == second["flow_id"]
423+
assert len(provider.begin_calls) == 1
424+
425+
403426
@pytest.mark.asyncio
404427
async def test_provider_registration_enforces_namespace_prefix() -> None:
405428
manager = CapabilityManager()

0 commit comments

Comments
 (0)