diff --git a/marimo/_messaging/tracebacks.py b/marimo/_messaging/tracebacks.py index 0e1a7019c4f..3e17617ad60 100644 --- a/marimo/_messaging/tracebacks.py +++ b/marimo/_messaging/tracebacks.py @@ -22,13 +22,14 @@ def _highlight_traceback(traceback: str) -> str: def write_traceback(traceback: str) -> None: - if isinstance(sys.stderr, Stderr): - sys.stderr._write_with_mimetype( - _highlight_traceback(_trim_traceback(traceback)), - mimetype="application/vnd.marimo+traceback", - ) - else: + # Short-circuit: avoid unnecessary function call if not custom stderr + if not isinstance(sys.stderr, Stderr): sys.stderr.write(traceback) + return + sys.stderr._write_with_mimetype( + _highlight_traceback(_trim_traceback(traceback)), + mimetype="application/vnd.marimo+traceback", + ) def _trim_traceback(traceback: str) -> str: diff --git a/marimo/_runtime/reload/autoreload.py b/marimo/_runtime/reload/autoreload.py index 61a1e4ced31..76c05afd9b2 100644 --- a/marimo/_runtime/reload/autoreload.py +++ b/marimo/_runtime/reload/autoreload.py @@ -402,10 +402,7 @@ def append_obj( name: str, obj: object, ) -> bool: - in_module = ( - hasattr(obj, "__module__") and obj.__module__ == module.__name__ - ) - if not in_module: + if getattr(obj, "__module__", None) != module.__name__: return False key = (module.__name__, name) @@ -431,15 +428,11 @@ def superreload( if old_objects is None: old_objects = {} - # collect old objects in the module - for name, obj in list(module.__dict__.items()): - if not append_obj(module, old_objects, name, obj): - continue - key = (module.__name__, name) - try: - old_objects.setdefault(key, []).append(weakref.ref(obj)) - except TypeError: - pass + # Collect old objects in the module (append_obj already adds weakref, avoid duplication here) + for name, obj in module.__dict__.items(): + append_obj(module, old_objects, name, obj) + + # reload module # reload module old_dict: dict[str, Any] | None = None @@ -449,7 +442,7 @@ def superreload( old_name = module.__name__ module.__dict__.clear() module.__dict__["__name__"] = old_name - module.__dict__["__loader__"] = old_dict["__loader__"] + module.__dict__["__loader__"] = old_dict.get("__loader__") except (TypeError, AttributeError, KeyError): pass @@ -472,13 +465,14 @@ def superreload( raise # iterate over all objects and update functions & classes - for name, new_obj in list(module.__dict__.items()): + for name, new_obj in module.__dict__.items(): key = (module.__name__, name) - if key not in old_objects: + old_refs = old_objects.get(key) + if old_refs is None: continue new_refs = [] - for old_ref in old_objects[key]: + for old_ref in old_refs: old_obj = old_ref() if old_obj is None: continue