Skip to content

Commit 3bd3f53

Browse files
committed
Ensure we correctly do effect validation of tools
1 parent 745d0c2 commit 3bd3f53

2 files changed

Lines changed: 70 additions & 27 deletions

File tree

reboot/mcp/server.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,21 @@
2424
from mcp.shared.message import ServerMessageMetadata
2525
from rbt.mcp.v1.session_rbt import Session
2626
from reboot.aio.applications import Application
27-
from reboot.aio.contexts import WorkflowContext
27+
from reboot.aio.contexts import EffectValidation, WorkflowContext
2828
from reboot.aio.external import ExternalContext
2929
from reboot.aio.types import StateRef
30+
from reboot.aio.workflows import at_least_once
3031
from reboot.mcp.event_store import DurableEventStore, replay
3132
from reboot.mcp.servicers.session import (
3233
SessionServicer,
3334
_servers,
3435
_context,
3536
)
36-
from reboot.aio.workflows import at_least_once
3737
from reboot.mcp.servicers.stream import StreamServicer
3838
from reboot.std.collections.v1 import sorted_map
3939
from rebootdev.aio.headers import CONSENSUS_ID_HEADER, STATE_REF_HEADER
4040
from rebootdev.memoize.v1.memoize_rbt import Memoize
41+
from rebootdev.settings import DOCS_BASE_URL
4142
from starlette.applications import Starlette
4243
from starlette.requests import Request
4344
from starlette.responses import Response, StreamingResponse
@@ -663,12 +664,12 @@ def _wrap_tool(fn: mcp.types.AnyFunction) -> mcp.types.AnyFunction:
663664

664665
wrapper_signature = signature.replace(parameters=wrapper_parameters)
665666

666-
async def wrapper(ctx: fastmcp.Context, *args, **kwargs):
667-
668-
context: WorkflowContext | None = _context.get()
669-
670-
assert context is not None
671-
667+
async def wrapper(
668+
ctx: fastmcp.Context,
669+
context: WorkflowContext,
670+
*args,
671+
**kwargs,
672+
):
672673
# To account for the lack of "intersection" types in
673674
# Python (which is actively being worked on), we instead
674675
# create a new dynamic `DurableContext` instance that
@@ -938,19 +939,49 @@ async def send_request_and_wait_for_result():
938939

939940
return fn(**dict(bound.arguments))
940941
except:
941-
# TODO: print stack trace after we've fixed `memoize`
942-
# effect validation bug.
943-
#
944-
# import traceback
945-
# traceback.print_exc()
942+
import traceback
943+
traceback.print_exc()
946944
raise
947945

948-
setattr(wrapper, "__signature__", wrapper_signature)
949-
wrapper.__name__ = fn.__name__
950-
wrapper.__doc__ = fn.__doc__
946+
async def wrapper_validating_effects(
947+
ctx: fastmcp.Context,
948+
*args,
949+
**kwargs,
950+
):
951+
context: WorkflowContext | None = _context.get()
952+
953+
assert context is not None
954+
955+
# Checkpoint the context since it is the `IdempotencyManager`.
956+
checkpoint = context.checkpoint()
957+
958+
result = await wrapper(ctx, context, *args, **kwargs)
959+
960+
if context._effect_validation == EffectValidation.DISABLED:
961+
return result
962+
963+
# Effect validation is enabled.
964+
logger.info(
965+
f"Re-running tool '{fn.__name__}' "
966+
f"to validate effects. See {DOCS_BASE_URL}/develop/side_effects "
967+
"for more information."
968+
)
969+
970+
# Restore the context to the checkpoint we took above so we
971+
# can re-execute `callable` as though it is being retried from
972+
# scratch.
973+
context.restore(checkpoint)
974+
975+
# TODO: check if `result` is different (we don't do this for
976+
# other effect validation so we're also not doing it now).
977+
978+
return await wrapper(ctx, context, *args, **kwargs)
951979

952-
return wrapper
980+
setattr(wrapper_validating_effects, "__signature__", wrapper_signature)
981+
wrapper_validating_effects.__name__ = fn.__name__
982+
wrapper_validating_effects.__doc__ = fn.__doc__
953983

984+
return wrapper_validating_effects
954985

955986

956987
class StreamableHTTPASGIApp:

reboot/mcp/servicers/session.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
import anyio
12
import mcp.types
23
import pickle
34
from contextvars import ContextVar
4-
from anyio import create_memory_object_stream
55
from anyio.streams.memory import (
66
MemoryObjectReceiveStream,
77
MemoryObjectSendStream,
@@ -76,9 +76,10 @@ def _get_request_streams(self, request_id: mcp.types.RequestId):
7676
# Create streams for communicating with MCP server.
7777
self._request_streams[request_id] = Streams(
7878
refs=1, # Initial reference count.
79-
read_stream=create_memory_object_stream[
79+
read_stream=anyio.create_memory_object_stream[
8080
SessionMessage | Exception](),
81-
write_stream=create_memory_object_stream[SessionMessage](),
81+
write_stream=anyio.create_memory_object_stream[
82+
SessionMessage](),
8283
)
8384
else:
8485
self._request_streams[request_id].refs += 1
@@ -121,8 +122,13 @@ async def HandleMessage(
121122
)
122123

123124
async def send_and_receive():
124-
125-
await read_stream_send.send(message)
125+
try:
126+
await read_stream_send.send(message)
127+
except anyio.ClosedResourceError:
128+
# Stream is closed, we must be re-executing
129+
# this function due to effect validation, just
130+
# return.
131+
return
126132

127133
async for write_message in write_stream_receive:
128134
logger.debug(
@@ -141,7 +147,9 @@ async def send_and_receive():
141147
write_message.message.root.id = event_id
142148
related_request_id = write_message.metadata.related_request_id
143149
assert related_request_id is not None
144-
self._write_request_ids[event_id] = (write_request_id, related_request_id)
150+
self._write_request_ids[event_id] = (
151+
write_request_id, related_request_id
152+
)
145153

146154
await stream.per_workflow(event_id).Put(
147155
context,
@@ -153,6 +161,7 @@ async def send_and_receive():
153161
write_message.message.root,
154162
mcp.types.JSONRPCResponse | mcp.types.JSONRPCError,
155163
):
164+
await read_stream_send.aclose()
156165
break
157166

158167
await at_least_once(
@@ -161,8 +170,10 @@ async def send_and_receive():
161170
send_and_receive,
162171
)
163172

164-
await read_stream_send.aclose()
165-
173+
# NOTE: need to await `run_task` within the
174+
# `self._get_request_streams()` context manager so
175+
# that we continue to use the same streams between
176+
# this function and `Run()`.
166177
await run_task
167178

168179
logger.debug(f"Completed ({type(message).__name__}): {message}")
@@ -221,10 +232,11 @@ async def Run(
221232
write_stream_send, _ = write_stream
222233

223234
async def server_run():
224-
global _servers
225-
server = _servers[path]
235+
assert _context.get() is None
226236
_context.set(context)
227237
try:
238+
global _servers
239+
server = _servers[path]
228240
await server.run(
229241
read_stream_receive,
230242
write_stream_send,

0 commit comments

Comments
 (0)