@@ -1290,6 +1290,8 @@ async def connect_mcp(
12901290 payload : ConnectMCPRequest ,
12911291 current_user : UserParam ,
12921292):
1293+ import asyncio
1294+
12931295 from mcp import ClientSession
12941296 from mcp .client .sse import sse_client
12951297 from mcp .client .stdio import (
@@ -1307,7 +1309,7 @@ async def connect_mcp(
13071309 StdioMcpConnection ,
13081310 validate_mcp_command ,
13091311 )
1310- from chainlit .session import WebsocketSession
1312+ from chainlit .session import McpSession , WebsocketSession
13111313
13121314 session = WebsocketSession .get_by_id (payload .sessionId )
13131315 context = init_ws_context (session )
@@ -1323,113 +1325,203 @@ async def connect_mcp(
13231325 )
13241326
13251327 mcp_enabled = config .features .mcp .enabled
1326- if mcp_enabled :
1327- if payload .name in session .mcp_sessions :
1328- old_client_session , old_exit_stack = session .mcp_sessions [payload .name ]
1329- if on_mcp_disconnect := config .code .on_mcp_disconnect :
1330- await on_mcp_disconnect (payload .name , old_client_session )
1328+ if not mcp_enabled :
1329+ raise HTTPException (
1330+ status_code = 400 ,
1331+ detail = "This app does not support MCP." ,
1332+ )
1333+
1334+ # Disconnect previous session for this name (reconnection)
1335+ if payload .name in session .mcp_sessions :
1336+ old_mcp = session .mcp_sessions .pop (payload .name )
1337+ if on_mcp_disconnect := config .code .on_mcp_disconnect :
13311338 try :
1332- await old_exit_stack . aclose ( )
1339+ await on_mcp_disconnect ( payload . name , old_mcp . client )
13331340 except Exception :
1334- pass
1335-
1341+ logger .debug (
1342+ "Error in on_mcp_disconnect callback for %s" ,
1343+ payload .name ,
1344+ exc_info = True ,
1345+ )
13361346 try :
1337- exit_stack = AsyncExitStack ()
1338- mcp_connection : McpConnection
1339-
1340- if payload .clientType == "sse" :
1341- if not config .features .mcp .sse .enabled :
1342- raise HTTPException (
1343- status_code = 400 ,
1344- detail = "SSE MCP is not enabled" ,
1345- )
1347+ await old_mcp .close ()
1348+ except Exception :
1349+ logger .debug (
1350+ "Error closing old MCP session %s" , payload .name , exc_info = True
1351+ )
13461352
1347- mcp_connection = SseMcpConnection (
1348- url = payload .url ,
1349- name = payload .name ,
1350- headers = getattr (payload , "headers" , None ),
1351- )
1353+ # ── Validate config before launching the background task ──
1354+ mcp_connection : McpConnection
1355+
1356+ if payload .clientType == "sse" :
1357+ if not config .features .mcp .sse .enabled :
1358+ raise HTTPException (
1359+ status_code = 400 ,
1360+ detail = "SSE MCP is not enabled" ,
1361+ )
1362+ mcp_connection = SseMcpConnection (
1363+ url = payload .url ,
1364+ name = payload .name ,
1365+ headers = getattr (payload , "headers" , None ),
1366+ )
1367+ elif payload .clientType == "stdio" :
1368+ if not config .features .mcp .stdio .enabled :
1369+ raise HTTPException (
1370+ status_code = 400 ,
1371+ detail = "Stdio MCP is not enabled" ,
1372+ )
1373+ env_from_cmd , command , args = validate_mcp_command (payload .fullCommand )
1374+ mcp_connection = StdioMcpConnection (
1375+ command = command , args = args , name = payload .name
1376+ )
1377+ elif payload .clientType == "streamable-http" :
1378+ if not config .features .mcp .streamable_http .enabled :
1379+ raise HTTPException (
1380+ status_code = 400 ,
1381+ detail = "HTTP MCP is not enabled" ,
1382+ )
1383+ mcp_connection = HttpMcpConnection (
1384+ url = payload .url ,
1385+ name = payload .name ,
1386+ headers = getattr (payload , "headers" , None ),
1387+ )
1388+ else :
1389+ raise HTTPException (
1390+ status_code = 400 ,
1391+ detail = f"Unknown MCP client type: { payload .clientType } " ,
1392+ )
13521393
1353- transport = await exit_stack .enter_async_context (
1354- sse_client (
1355- url = mcp_connection .url ,
1356- headers = mcp_connection .headers ,
1394+ # ── Launch the MCP connection in its own background task ──
1395+ #
1396+ # The background task owns the AsyncExitStack: it enters all context
1397+ # managers, calls initialize(), signals ``ready_event``, and then
1398+ # blocks on ``stop_event.wait()``. When the stop event fires the
1399+ # task wakes up and closes the exit stack *in the same task* that
1400+ # opened it — avoiding the cross-task cancel-scope corruption from
1401+ # https://github.com/Chainlit/chainlit/issues/2182.
1402+
1403+ ready_event : asyncio .Event = asyncio .Event ()
1404+ stop_event : asyncio .Event = asyncio .Event ()
1405+ # Mutable container to pass the ClientSession back from the bg task.
1406+ result_holder : dict [str , object ] = {}
1407+
1408+ async def _mcp_session_runner () -> None :
1409+ exit_stack = AsyncExitStack ()
1410+ try :
1411+ try :
1412+ if isinstance (mcp_connection , SseMcpConnection ):
1413+ transport = await exit_stack .enter_async_context (
1414+ sse_client (
1415+ url = mcp_connection .url ,
1416+ headers = mcp_connection .headers ,
1417+ )
13571418 )
1358- )
1359- elif payload .clientType == "stdio" :
1360- if not config .features .mcp .stdio .enabled :
1361- raise HTTPException (
1362- status_code = 400 ,
1363- detail = "Stdio MCP is not enabled" ,
1419+ elif isinstance (mcp_connection , StdioMcpConnection ):
1420+ env = get_default_environment ()
1421+ env .update (env_from_cmd )
1422+ server_params = StdioServerParameters (
1423+ command = command , args = args , env = env
13641424 )
1425+ transport = await exit_stack .enter_async_context (
1426+ stdio_client (server_params )
1427+ )
1428+ elif isinstance (mcp_connection , HttpMcpConnection ):
1429+ transport = await exit_stack .enter_async_context (
1430+ streamablehttp_client (
1431+ url = mcp_connection .url ,
1432+ headers = mcp_connection .headers ,
1433+ )
1434+ )
1435+ else :
1436+ raise ValueError (f"Unknown client type: { payload .clientType } " )
13651437
1366- env_from_cmd , command , args = validate_mcp_command (payload .fullCommand )
1367- mcp_connection = StdioMcpConnection (
1368- command = command , args = args , name = payload .name
1369- )
1370-
1371- env = get_default_environment ()
1372- env .update (env_from_cmd )
1373- # Create the server parameters
1374- server_params = StdioServerParameters (
1375- command = command , args = args , env = env
1376- )
1377-
1378- transport = await exit_stack .enter_async_context (
1379- stdio_client (server_params )
1380- )
1438+ read , write = transport [:2 ]
13811439
1382- elif payload .clientType == "streamable-http" :
1383- if not config .features .mcp .streamable_http .enabled :
1384- raise HTTPException (
1385- status_code = 400 ,
1386- detail = "HTTP MCP is not enabled" ,
1387- )
1388- mcp_connection = HttpMcpConnection (
1389- url = payload .url ,
1390- name = payload .name ,
1391- headers = getattr (payload , "headers" , None ),
1392- )
1393- transport = await exit_stack .enter_async_context (
1394- streamablehttp_client (
1395- url = mcp_connection .url ,
1396- headers = mcp_connection .headers ,
1440+ mcp_client : ClientSession = await exit_stack .enter_async_context (
1441+ ClientSession (
1442+ read_stream = read ,
1443+ write_stream = write ,
1444+ sampling_callback = None ,
13971445 )
13981446 )
13991447
1400- # The transport can return (read, write) for stdio, sse
1401- # Or (read, write, get_session_id) for streamable-http
1402- # We are only interested in the read and write streams here.
1403- read , write = transport [:2 ]
1448+ await mcp_client .initialize ()
1449+ result_holder ["client" ] = mcp_client
14041450
1405- mcp_session : ClientSession = await exit_stack .enter_async_context (
1406- ClientSession (
1407- read_stream = read , write_stream = write , sampling_callback = None
1451+ except BaseException as exc :
1452+ result_holder ["error" ] = exc
1453+ return # outer finally closes exit_stack
1454+ finally :
1455+ # Always signal the caller so it doesn't wait forever.
1456+ ready_event .set ()
1457+
1458+ # ── Keep the task (and the exit stack) alive ──
1459+ try :
1460+ await stop_event .wait ()
1461+ except asyncio .CancelledError :
1462+ logger .debug ("MCP background task for %r cancelled" , payload .name )
1463+ finally :
1464+ # Close exit_stack in ALL paths (error, normal shutdown,
1465+ # cancellation) — always in the same task that opened it.
1466+ logger .debug ("Closing MCP exit stack for %r (same-task)" , payload .name )
1467+ try :
1468+ await exit_stack .aclose ()
1469+ except BaseException :
1470+ logger .debug (
1471+ "Error closing MCP exit stack for %r" ,
1472+ payload .name ,
1473+ exc_info = True ,
14081474 )
1409- )
14101475
1411- # Initialize the session
1412- await mcp_session .initialize ()
1476+ task = asyncio .create_task (
1477+ _mcp_session_runner (), name = f"mcp-session-{ payload .name } "
1478+ )
14131479
1414- # Store the session
1415- session . mcp_sessions [ mcp_connection . name ] = ( mcp_session , exit_stack )
1480+ # Wait for the background task to finish initialisation.
1481+ await ready_event . wait ( )
14161482
1417- # Call the callback
1418- if config .code .on_mcp_connect :
1419- await config .code .on_mcp_connect (mcp_connection , mcp_session )
1483+ if "error" in result_holder :
1484+ # The task already exited and cleaned up its exit stack.
1485+ # Make sure the task itself is fully done.
1486+ try :
1487+ await task
1488+ except BaseException :
1489+ pass
1490+ return JSONResponse (
1491+ status_code = 400 ,
1492+ content = {
1493+ "detail" : f"Could not connect to the MCP: { result_holder ['error' ]!s} "
1494+ },
1495+ )
14201496
1497+ mcp_client_session = cast ("ClientSession" , result_holder ["client" ])
1498+
1499+ # Call the user callback
1500+ if config .code .on_mcp_connect :
1501+ try :
1502+ await config .code .on_mcp_connect (mcp_connection , mcp_client_session )
14211503 except Exception as e :
1422- raise HTTPException (
1504+ # Callback failed — tear down the connection.
1505+ stop_event .set ()
1506+ try :
1507+ await task
1508+ except BaseException :
1509+ pass
1510+ return JSONResponse (
14231511 status_code = 400 ,
1424- detail = f"Could not connect to the MCP: { e !s} " ,
1512+ content = { "detail" : f"Could not connect to the MCP: { e !s} " } ,
14251513 )
1426- else :
1427- raise HTTPException (
1428- status_code = 400 ,
1429- detail = "This app does not support MCP." ,
1430- )
14311514
1432- tool_list = await mcp_session .list_tools ()
1515+ # Store the session
1516+ mcp_session_obj = McpSession (
1517+ name = mcp_connection .name ,
1518+ client = mcp_client_session ,
1519+ task = task ,
1520+ stop_event = stop_event ,
1521+ )
1522+ session .mcp_sessions [mcp_connection .name ] = mcp_session_obj
1523+
1524+ tool_list = await mcp_client_session .list_tools ()
14331525
14341526 return JSONResponse (
14351527 content = {
@@ -1475,17 +1567,11 @@ async def disconnect_mcp(
14751567
14761568 callback = config .code .on_mcp_disconnect
14771569 if payload .name in session .mcp_sessions :
1570+ mcp_session_obj = session .mcp_sessions .pop (payload .name )
14781571 try :
1479- client_session , exit_stack = session .mcp_sessions [payload .name ]
14801572 if callback :
1481- await callback (payload .name , client_session )
1482-
1483- try :
1484- await exit_stack .aclose ()
1485- except Exception :
1486- pass
1487- del session .mcp_sessions [payload .name ]
1488-
1573+ await callback (payload .name , mcp_session_obj .client )
1574+ await mcp_session_obj .close ()
14891575 except Exception as e :
14901576 raise HTTPException (
14911577 status_code = 400 ,
0 commit comments