|
1 | 1 | import logging |
2 | 2 | import os |
| 3 | +import subprocess |
3 | 4 | import sys |
4 | 5 | from contextlib import asynccontextmanager |
5 | 6 | from pathlib import Path |
|
48 | 49 | PROCESS_TERMINATION_TIMEOUT = 2.0 |
49 | 50 |
|
50 | 51 |
|
| 52 | +def _is_jupyter_notebook() -> bool: |
| 53 | + """ |
| 54 | + Detect if running in a Jupyter notebook or IPython environment. |
| 55 | +
|
| 56 | + Returns: |
| 57 | + bool: True if running in Jupyter/IPython, False otherwise |
| 58 | + """ |
| 59 | + try: |
| 60 | + from IPython import get_ipython # type: ignore[import-not-found] |
| 61 | + |
| 62 | + ipython = get_ipython() # type: ignore[no-untyped-call] |
| 63 | + return ipython is not None and ipython.__class__.__name__ in ("ZMQInteractiveShell", "TerminalInteractiveShell") |
| 64 | + except ImportError: |
| 65 | + return False |
| 66 | + |
| 67 | + |
| 68 | +def _print_stderr(line: str, errlog: TextIO) -> None: |
| 69 | + """ |
| 70 | + Print stderr output, using IPython's display system if in Jupyter notebook. |
| 71 | +
|
| 72 | + Args: |
| 73 | + line: The line to print |
| 74 | + errlog: The fallback TextIO stream (used when not in Jupyter) |
| 75 | + """ |
| 76 | + if _is_jupyter_notebook(): |
| 77 | + try: |
| 78 | + from IPython.display import HTML, display # type: ignore[import-not-found] |
| 79 | + |
| 80 | + # Use IPython's display system with red color for stderr |
| 81 | + # This ensures proper rendering in Jupyter notebooks |
| 82 | + display(HTML(f'<pre style="color: red;">{line}</pre>')) # type: ignore[no-untyped-call] |
| 83 | + except Exception: |
| 84 | + # If IPython display fails, fall back to regular print |
| 85 | + # Log the error but continue (non-critical) |
| 86 | + logger.debug("Failed to use IPython display for stderr, falling back to print", exc_info=True) |
| 87 | + print(line, file=errlog) |
| 88 | + else: |
| 89 | + # Not in Jupyter, use standard stderr redirection |
| 90 | + print(line, file=errlog) |
| 91 | + |
| 92 | + |
51 | 93 | def get_default_environment() -> dict[str, str]: |
52 | 94 | """ |
53 | 95 | Returns a default environment object including only environment variables deemed |
@@ -102,11 +144,121 @@ class StdioServerParameters(BaseModel): |
102 | 144 | """ |
103 | 145 |
|
104 | 146 |
|
| 147 | +async def _stdout_reader( |
| 148 | + process: Process | FallbackProcess, |
| 149 | + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception], |
| 150 | + encoding: str, |
| 151 | + encoding_error_handler: str, |
| 152 | +): |
| 153 | + """Read stdout from the process and parse JSONRPC messages.""" |
| 154 | + assert process.stdout, "Opened process is missing stdout" |
| 155 | + |
| 156 | + try: |
| 157 | + async with read_stream_writer: |
| 158 | + buffer = "" |
| 159 | + async for chunk in TextReceiveStream( |
| 160 | + process.stdout, |
| 161 | + encoding=encoding, |
| 162 | + errors=encoding_error_handler, |
| 163 | + ): |
| 164 | + lines = (buffer + chunk).split("\n") |
| 165 | + buffer = lines.pop() |
| 166 | + |
| 167 | + for line in lines: |
| 168 | + try: |
| 169 | + message = types.JSONRPCMessage.model_validate_json(line) |
| 170 | + except Exception as exc: # pragma: no cover |
| 171 | + logger.exception("Failed to parse JSONRPC message from server") |
| 172 | + await read_stream_writer.send(exc) |
| 173 | + continue |
| 174 | + |
| 175 | + session_message = SessionMessage(message) |
| 176 | + await read_stream_writer.send(session_message) |
| 177 | + except anyio.ClosedResourceError: # pragma: no cover |
| 178 | + await anyio.lowlevel.checkpoint() |
| 179 | + |
| 180 | + |
| 181 | +async def _stdin_writer( |
| 182 | + process: Process | FallbackProcess, |
| 183 | + write_stream_reader: MemoryObjectReceiveStream[SessionMessage], |
| 184 | + encoding: str, |
| 185 | + encoding_error_handler: str, |
| 186 | +): |
| 187 | + """Write session messages to the process stdin.""" |
| 188 | + assert process.stdin, "Opened process is missing stdin" |
| 189 | + |
| 190 | + try: |
| 191 | + async with write_stream_reader: |
| 192 | + async for session_message in write_stream_reader: |
| 193 | + json = session_message.message.model_dump_json(by_alias=True, exclude_none=True) |
| 194 | + await process.stdin.send( |
| 195 | + (json + "\n").encode( |
| 196 | + encoding=encoding, |
| 197 | + errors=encoding_error_handler, |
| 198 | + ) |
| 199 | + ) |
| 200 | + except anyio.ClosedResourceError: # pragma: no cover |
| 201 | + await anyio.lowlevel.checkpoint() |
| 202 | + |
| 203 | + |
| 204 | +async def _stderr_reader( |
| 205 | + process: Process | FallbackProcess, |
| 206 | + errlog: TextIO, |
| 207 | + encoding: str, |
| 208 | + encoding_error_handler: str, |
| 209 | +): |
| 210 | + """Read stderr from the process and display it appropriately.""" |
| 211 | + if not process.stderr: |
| 212 | + return |
| 213 | + |
| 214 | + try: |
| 215 | + buffer = "" |
| 216 | + async for chunk in TextReceiveStream( |
| 217 | + process.stderr, |
| 218 | + encoding=encoding, |
| 219 | + errors=encoding_error_handler, |
| 220 | + ): |
| 221 | + lines = (buffer + chunk).split("\n") |
| 222 | + buffer = lines.pop() |
| 223 | + |
| 224 | + for line in lines: |
| 225 | + if line.strip(): # Only print non-empty lines |
| 226 | + try: |
| 227 | + _print_stderr(line, errlog) |
| 228 | + except Exception: |
| 229 | + # Log errors but continue (non-critical) |
| 230 | + logger.debug("Failed to print stderr line", exc_info=True) |
| 231 | + |
| 232 | + # Print any remaining buffer content |
| 233 | + if buffer.strip(): |
| 234 | + try: |
| 235 | + _print_stderr(buffer, errlog) |
| 236 | + except Exception: |
| 237 | + logger.debug("Failed to print final stderr buffer", exc_info=True) |
| 238 | + except anyio.ClosedResourceError: # pragma: no cover |
| 239 | + await anyio.lowlevel.checkpoint() |
| 240 | + except Exception: |
| 241 | + # Log errors but continue (non-critical) |
| 242 | + logger.debug("Error reading stderr", exc_info=True) |
| 243 | + |
| 244 | + |
105 | 245 | @asynccontextmanager |
106 | 246 | async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr): |
107 | 247 | """ |
108 | 248 | Client transport for stdio: this will connect to a server by spawning a |
109 | 249 | process and communicating with it over stdin/stdout. |
| 250 | +
|
| 251 | + This function automatically handles stderr output in a way that is compatible |
| 252 | + with Jupyter notebook environments. When running in Jupyter, stderr output |
| 253 | + is displayed using IPython's display system with red color formatting. |
| 254 | + When not in Jupyter, stderr is redirected to the provided errlog stream |
| 255 | + (defaults to sys.stderr). |
| 256 | +
|
| 257 | + Args: |
| 258 | + server: Parameters for the server process to spawn |
| 259 | + errlog: TextIO stream for stderr output when not in Jupyter (defaults to sys.stderr). |
| 260 | + This parameter is kept for backward compatibility but may be ignored |
| 261 | + when running in Jupyter notebook environments. |
110 | 262 | """ |
111 | 263 | read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] |
112 | 264 | read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] |
@@ -136,55 +288,14 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder |
136 | 288 | await write_stream_reader.aclose() |
137 | 289 | raise |
138 | 290 |
|
139 | | - async def stdout_reader(): |
140 | | - assert process.stdout, "Opened process is missing stdout" |
141 | | - |
142 | | - try: |
143 | | - async with read_stream_writer: |
144 | | - buffer = "" |
145 | | - async for chunk in TextReceiveStream( |
146 | | - process.stdout, |
147 | | - encoding=server.encoding, |
148 | | - errors=server.encoding_error_handler, |
149 | | - ): |
150 | | - lines = (buffer + chunk).split("\n") |
151 | | - buffer = lines.pop() |
152 | | - |
153 | | - for line in lines: |
154 | | - try: |
155 | | - message = types.JSONRPCMessage.model_validate_json(line) |
156 | | - except Exception as exc: # pragma: no cover |
157 | | - logger.exception("Failed to parse JSONRPC message from server") |
158 | | - await read_stream_writer.send(exc) |
159 | | - continue |
160 | | - |
161 | | - session_message = SessionMessage(message) |
162 | | - await read_stream_writer.send(session_message) |
163 | | - except anyio.ClosedResourceError: # pragma: no cover |
164 | | - await anyio.lowlevel.checkpoint() |
165 | | - |
166 | | - async def stdin_writer(): |
167 | | - assert process.stdin, "Opened process is missing stdin" |
168 | | - |
169 | | - try: |
170 | | - async with write_stream_reader: |
171 | | - async for session_message in write_stream_reader: |
172 | | - json = session_message.message.model_dump_json(by_alias=True, exclude_none=True) |
173 | | - await process.stdin.send( |
174 | | - (json + "\n").encode( |
175 | | - encoding=server.encoding, |
176 | | - errors=server.encoding_error_handler, |
177 | | - ) |
178 | | - ) |
179 | | - except anyio.ClosedResourceError: # pragma: no cover |
180 | | - await anyio.lowlevel.checkpoint() |
181 | | - |
182 | 291 | async with ( |
183 | 292 | anyio.create_task_group() as tg, |
184 | 293 | process, |
185 | 294 | ): |
186 | | - tg.start_soon(stdout_reader) |
187 | | - tg.start_soon(stdin_writer) |
| 295 | + tg.start_soon(_stdout_reader, process, read_stream_writer, server.encoding, server.encoding_error_handler) |
| 296 | + tg.start_soon(_stdin_writer, process, write_stream_reader, server.encoding, server.encoding_error_handler) |
| 297 | + if process.stderr: |
| 298 | + tg.start_soon(_stderr_reader, process, errlog, server.encoding, server.encoding_error_handler) |
188 | 299 | try: |
189 | 300 | yield read_stream, write_stream |
190 | 301 | finally: |
@@ -244,14 +355,19 @@ async def _create_platform_compatible_process( |
244 | 355 |
|
245 | 356 | Unix: Creates process in a new session/process group for killpg support |
246 | 357 | Windows: Creates process in a Job Object for reliable child termination |
| 358 | +
|
| 359 | + Note: stderr is piped (not redirected) to allow async reading for Jupyter |
| 360 | + notebook compatibility. The errlog parameter is kept for backward compatibility |
| 361 | + but is only used when not in Jupyter environments. |
247 | 362 | """ |
248 | 363 | if sys.platform == "win32": # pragma: no cover |
249 | | - process = await create_windows_process(command, args, env, errlog, cwd) |
| 364 | + process = await create_windows_process(command, args, env, errlog, cwd, pipe_stderr=True) |
250 | 365 | else: |
| 366 | + # Pipe stderr instead of redirecting to allow async reading |
251 | 367 | process = await anyio.open_process( |
252 | 368 | [command, *args], |
253 | 369 | env=env, |
254 | | - stderr=errlog, |
| 370 | + stderr=subprocess.PIPE, |
255 | 371 | cwd=cwd, |
256 | 372 | start_new_session=True, |
257 | 373 | ) # pragma: no cover |
|
0 commit comments