Skip to content

Commit b04a321

Browse files
GhimBoon and nigiva committed
fix: proper MCP session lifecycle with background tasks
Each MCP connection now runs in its own asyncio.Task that owns the AsyncExitStack. The task enters transports, initializes the ClientSession, signals a ready event, then blocks on a stop event. On shutdown, the exit stack is closed in the same task that opened it, avoiding the cross-task cancel-scope corruption described in #2182.

Based on the solution proposed by @nigiva: #2182 (comment)

Changes:
- Add McpSession dataclass (session.py) with stop_event, close(), and backward-compatible __iter__ for tuple unpacking
- Rewrite connect_mcp (server.py) to launch connections in background tasks with ready/stop event coordination
- Rewrite disconnect_mcp to use McpSession.close()
- Update WebsocketSession.delete() to iterate and close McpSession objects
- Catch BaseException (not just Exception) for streamablehttp transport errors (BaseExceptionGroup inherits from BaseException)
- Use nested try/finally to ensure exit stack cleanup in all code paths
- Return JSONResponse on errors instead of raising HTTPException (avoids issues with BaseHTTPMiddleware.call_next())
- Add comprehensive tests for McpSession lifecycle

Co-authored-by: nigiva <nigiva@users.noreply.github.com>
1 parent d69c294 commit b04a321

3 files changed

Lines changed: 392 additions & 104 deletions

File tree

backend/chainlit/server.py

Lines changed: 181 additions & 95 deletions
Original file line number | Diff line number | Diff line change
@@ -1290,6 +1290,8 @@ async def connect_mcp(
12901290
payload: ConnectMCPRequest,
12911291
current_user: UserParam,
12921292
):
1293+
import asyncio
1294+
12931295
from mcp import ClientSession
12941296
from mcp.client.sse import sse_client
12951297
from mcp.client.stdio import (
@@ -1307,7 +1309,7 @@ async def connect_mcp(
13071309
StdioMcpConnection,
13081310
validate_mcp_command,
13091311
)
1310-
from chainlit.session import WebsocketSession
1312+
from chainlit.session import McpSession, WebsocketSession
13111313

13121314
session = WebsocketSession.get_by_id(payload.sessionId)
13131315
context = init_ws_context(session)
@@ -1323,113 +1325,203 @@ async def connect_mcp(
13231325
)
13241326

13251327
mcp_enabled = config.features.mcp.enabled
1326-
if mcp_enabled:
1327-
if payload.name in session.mcp_sessions:
1328-
old_client_session, old_exit_stack = session.mcp_sessions[payload.name]
1329-
if on_mcp_disconnect := config.code.on_mcp_disconnect:
1330-
await on_mcp_disconnect(payload.name, old_client_session)
1328+
if not mcp_enabled:
1329+
raise HTTPException(
1330+
status_code=400,
1331+
detail="This app does not support MCP.",
1332+
)
1333+
1334+
# Disconnect previous session for this name (reconnection)
1335+
if payload.name in session.mcp_sessions:
1336+
old_mcp = session.mcp_sessions.pop(payload.name)
1337+
if on_mcp_disconnect := config.code.on_mcp_disconnect:
13311338
try:
1332-
await old_exit_stack.aclose()
1339+
await on_mcp_disconnect(payload.name, old_mcp.client)
13331340
except Exception:
1334-
pass
1335-
1341+
logger.debug(
1342+
"Error in on_mcp_disconnect callback for %s",
1343+
payload.name,
1344+
exc_info=True,
1345+
)
13361346
try:
1337-
exit_stack = AsyncExitStack()
1338-
mcp_connection: McpConnection
1339-
1340-
if payload.clientType == "sse":
1341-
if not config.features.mcp.sse.enabled:
1342-
raise HTTPException(
1343-
status_code=400,
1344-
detail="SSE MCP is not enabled",
1345-
)
1347+
await old_mcp.close()
1348+
except Exception:
1349+
logger.debug(
1350+
"Error closing old MCP session %s", payload.name, exc_info=True
1351+
)
13461352

1347-
mcp_connection = SseMcpConnection(
1348-
url=payload.url,
1349-
name=payload.name,
1350-
headers=getattr(payload, "headers", None),
1351-
)
1353+
# ── Validate config before launching the background task ──
1354+
mcp_connection: McpConnection
1355+
1356+
if payload.clientType == "sse":
1357+
if not config.features.mcp.sse.enabled:
1358+
raise HTTPException(
1359+
status_code=400,
1360+
detail="SSE MCP is not enabled",
1361+
)
1362+
mcp_connection = SseMcpConnection(
1363+
url=payload.url,
1364+
name=payload.name,
1365+
headers=getattr(payload, "headers", None),
1366+
)
1367+
elif payload.clientType == "stdio":
1368+
if not config.features.mcp.stdio.enabled:
1369+
raise HTTPException(
1370+
status_code=400,
1371+
detail="Stdio MCP is not enabled",
1372+
)
1373+
env_from_cmd, command, args = validate_mcp_command(payload.fullCommand)
1374+
mcp_connection = StdioMcpConnection(
1375+
command=command, args=args, name=payload.name
1376+
)
1377+
elif payload.clientType == "streamable-http":
1378+
if not config.features.mcp.streamable_http.enabled:
1379+
raise HTTPException(
1380+
status_code=400,
1381+
detail="HTTP MCP is not enabled",
1382+
)
1383+
mcp_connection = HttpMcpConnection(
1384+
url=payload.url,
1385+
name=payload.name,
1386+
headers=getattr(payload, "headers", None),
1387+
)
1388+
else:
1389+
raise HTTPException(
1390+
status_code=400,
1391+
detail=f"Unknown MCP client type: {payload.clientType}",
1392+
)
13521393

1353-
transport = await exit_stack.enter_async_context(
1354-
sse_client(
1355-
url=mcp_connection.url,
1356-
headers=mcp_connection.headers,
1394+
# ── Launch the MCP connection in its own background task ──
1395+
#
1396+
# The background task owns the AsyncExitStack: it enters all context
1397+
# managers, calls initialize(), signals ``ready_event``, and then
1398+
# blocks on ``stop_event.wait()``. When the stop event fires the
1399+
# task wakes up and closes the exit stack *in the same task* that
1400+
# opened it — avoiding the cross-task cancel-scope corruption from
1401+
# https://github.com/Chainlit/chainlit/issues/2182.
1402+
1403+
ready_event: asyncio.Event = asyncio.Event()
1404+
stop_event: asyncio.Event = asyncio.Event()
1405+
# Mutable container to pass the ClientSession back from the bg task.
1406+
result_holder: dict[str, object] = {}
1407+
1408+
async def _mcp_session_runner() -> None:
1409+
exit_stack = AsyncExitStack()
1410+
try:
1411+
try:
1412+
if isinstance(mcp_connection, SseMcpConnection):
1413+
transport = await exit_stack.enter_async_context(
1414+
sse_client(
1415+
url=mcp_connection.url,
1416+
headers=mcp_connection.headers,
1417+
)
13571418
)
1358-
)
1359-
elif payload.clientType == "stdio":
1360-
if not config.features.mcp.stdio.enabled:
1361-
raise HTTPException(
1362-
status_code=400,
1363-
detail="Stdio MCP is not enabled",
1419+
elif isinstance(mcp_connection, StdioMcpConnection):
1420+
env = get_default_environment()
1421+
env.update(env_from_cmd)
1422+
server_params = StdioServerParameters(
1423+
command=command, args=args, env=env
13641424
)
1425+
transport = await exit_stack.enter_async_context(
1426+
stdio_client(server_params)
1427+
)
1428+
elif isinstance(mcp_connection, HttpMcpConnection):
1429+
transport = await exit_stack.enter_async_context(
1430+
streamablehttp_client(
1431+
url=mcp_connection.url,
1432+
headers=mcp_connection.headers,
1433+
)
1434+
)
1435+
else:
1436+
raise ValueError(f"Unknown client type: {payload.clientType}")
13651437

1366-
env_from_cmd, command, args = validate_mcp_command(payload.fullCommand)
1367-
mcp_connection = StdioMcpConnection(
1368-
command=command, args=args, name=payload.name
1369-
)
1370-
1371-
env = get_default_environment()
1372-
env.update(env_from_cmd)
1373-
# Create the server parameters
1374-
server_params = StdioServerParameters(
1375-
command=command, args=args, env=env
1376-
)
1377-
1378-
transport = await exit_stack.enter_async_context(
1379-
stdio_client(server_params)
1380-
)
1438+
read, write = transport[:2]
13811439

1382-
elif payload.clientType == "streamable-http":
1383-
if not config.features.mcp.streamable_http.enabled:
1384-
raise HTTPException(
1385-
status_code=400,
1386-
detail="HTTP MCP is not enabled",
1387-
)
1388-
mcp_connection = HttpMcpConnection(
1389-
url=payload.url,
1390-
name=payload.name,
1391-
headers=getattr(payload, "headers", None),
1392-
)
1393-
transport = await exit_stack.enter_async_context(
1394-
streamablehttp_client(
1395-
url=mcp_connection.url,
1396-
headers=mcp_connection.headers,
1440+
mcp_client: ClientSession = await exit_stack.enter_async_context(
1441+
ClientSession(
1442+
read_stream=read,
1443+
write_stream=write,
1444+
sampling_callback=None,
13971445
)
13981446
)
13991447

1400-
# The transport can return (read, write) for stdio, sse
1401-
# Or (read, write, get_session_id) for streamable-http
1402-
# We are only interested in the read and write streams here.
1403-
read, write = transport[:2]
1448+
await mcp_client.initialize()
1449+
result_holder["client"] = mcp_client
14041450

1405-
mcp_session: ClientSession = await exit_stack.enter_async_context(
1406-
ClientSession(
1407-
read_stream=read, write_stream=write, sampling_callback=None
1451+
except BaseException as exc:
1452+
result_holder["error"] = exc
1453+
return # outer finally closes exit_stack
1454+
finally:
1455+
# Always signal the caller so it doesn't wait forever.
1456+
ready_event.set()
1457+
1458+
# ── Keep the task (and the exit stack) alive ──
1459+
try:
1460+
await stop_event.wait()
1461+
except asyncio.CancelledError:
1462+
logger.debug("MCP background task for %r cancelled", payload.name)
1463+
finally:
1464+
# Close exit_stack in ALL paths (error, normal shutdown,
1465+
# cancellation) — always in the same task that opened it.
1466+
logger.debug("Closing MCP exit stack for %r (same-task)", payload.name)
1467+
try:
1468+
await exit_stack.aclose()
1469+
except BaseException:
1470+
logger.debug(
1471+
"Error closing MCP exit stack for %r",
1472+
payload.name,
1473+
exc_info=True,
14081474
)
1409-
)
14101475

1411-
# Initialize the session
1412-
await mcp_session.initialize()
1476+
task = asyncio.create_task(
1477+
_mcp_session_runner(), name=f"mcp-session-{payload.name}"
1478+
)
14131479

1414-
# Store the session
1415-
session.mcp_sessions[mcp_connection.name] = (mcp_session, exit_stack)
1480+
# Wait for the background task to finish initialisation.
1481+
await ready_event.wait()
14161482

1417-
# Call the callback
1418-
if config.code.on_mcp_connect:
1419-
await config.code.on_mcp_connect(mcp_connection, mcp_session)
1483+
if "error" in result_holder:
1484+
# The task already exited and cleaned up its exit stack.
1485+
# Make sure the task itself is fully done.
1486+
try:
1487+
await task
1488+
except BaseException:
1489+
pass
1490+
return JSONResponse(
1491+
status_code=400,
1492+
content={
1493+
"detail": f"Could not connect to the MCP: {result_holder['error']!s}"
1494+
},
1495+
)
14201496

1497+
mcp_client_session = cast("ClientSession", result_holder["client"])
1498+
1499+
# Call the user callback
1500+
if config.code.on_mcp_connect:
1501+
try:
1502+
await config.code.on_mcp_connect(mcp_connection, mcp_client_session)
14211503
except Exception as e:
1422-
raise HTTPException(
1504+
# Callback failed — tear down the connection.
1505+
stop_event.set()
1506+
try:
1507+
await task
1508+
except BaseException:
1509+
pass
1510+
return JSONResponse(
14231511
status_code=400,
1424-
detail=f"Could not connect to the MCP: {e!s}",
1512+
content={"detail": f"Could not connect to the MCP: {e!s}"},
14251513
)
1426-
else:
1427-
raise HTTPException(
1428-
status_code=400,
1429-
detail="This app does not support MCP.",
1430-
)
14311514

1432-
tool_list = await mcp_session.list_tools()
1515+
# Store the session
1516+
mcp_session_obj = McpSession(
1517+
name=mcp_connection.name,
1518+
client=mcp_client_session,
1519+
task=task,
1520+
stop_event=stop_event,
1521+
)
1522+
session.mcp_sessions[mcp_connection.name] = mcp_session_obj
1523+
1524+
tool_list = await mcp_client_session.list_tools()
14331525

14341526
return JSONResponse(
14351527
content={
@@ -1475,17 +1567,11 @@ async def disconnect_mcp(
14751567

14761568
callback = config.code.on_mcp_disconnect
14771569
if payload.name in session.mcp_sessions:
1570+
mcp_session_obj = session.mcp_sessions.pop(payload.name)
14781571
try:
1479-
client_session, exit_stack = session.mcp_sessions[payload.name]
14801572
if callback:
1481-
await callback(payload.name, client_session)
1482-
1483-
try:
1484-
await exit_stack.aclose()
1485-
except Exception:
1486-
pass
1487-
del session.mcp_sessions[payload.name]
1488-
1573+
await callback(payload.name, mcp_session_obj.client)
1574+
await mcp_session_obj.close()
14891575
except Exception as e:
14901576
raise HTTPException(
14911577
status_code=400,

0 commit comments

Comments
 (0)