Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions marimo/_messaging/tracebacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ def _highlight_traceback(traceback: str) -> str:


def write_traceback(traceback: str) -> None:
if isinstance(sys.stderr, Stderr):
sys.stderr._write_with_mimetype(
_highlight_traceback(_trim_traceback(traceback)),
mimetype="application/vnd.marimo+traceback",
)
else:
# Short-circuit: avoid unnecessary function call if not custom stderr
if not isinstance(sys.stderr, Stderr):
sys.stderr.write(traceback)
return
sys.stderr._write_with_mimetype(
_highlight_traceback(_trim_traceback(traceback)),
mimetype="application/vnd.marimo+traceback",
)


def _trim_traceback(traceback: str) -> str:
Expand Down
28 changes: 11 additions & 17 deletions marimo/_runtime/reload/autoreload.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,7 @@ def append_obj(
name: str,
obj: object,
) -> bool:
in_module = (
hasattr(obj, "__module__") and obj.__module__ == module.__name__
)
if not in_module:
if getattr(obj, "__module__", None) != module.__name__:
return False

key = (module.__name__, name)
Expand All @@ -431,15 +428,11 @@ def superreload(
if old_objects is None:
old_objects = {}

# collect old objects in the module
for name, obj in list(module.__dict__.items()):
if not append_obj(module, old_objects, name, obj):
continue
key = (module.__name__, name)
try:
old_objects.setdefault(key, []).append(weakref.ref(obj))
except TypeError:
pass
# Collect old objects in the module (append_obj already adds weakref, avoid duplication here)
for name, obj in module.__dict__.items():
append_obj(module, old_objects, name, obj)

# reload module

# reload module
old_dict: dict[str, Any] | None = None
Expand All @@ -449,7 +442,7 @@ def superreload(
old_name = module.__name__
module.__dict__.clear()
module.__dict__["__name__"] = old_name
module.__dict__["__loader__"] = old_dict["__loader__"]
module.__dict__["__loader__"] = old_dict.get("__loader__")
except (TypeError, AttributeError, KeyError):
pass

Expand All @@ -472,13 +465,14 @@ def superreload(
raise

# iterate over all objects and update functions & classes
for name, new_obj in list(module.__dict__.items()):
for name, new_obj in module.__dict__.items():
key = (module.__name__, name)
if key not in old_objects:
old_refs = old_objects.get(key)
if old_refs is None:
continue

new_refs = []
for old_ref in old_objects[key]:
for old_ref in old_refs:
old_obj = old_ref()
if old_obj is None:
continue
Expand Down