From bb590e359215dd2fdfaa51e413d4cba30fc52f56 Mon Sep 17 00:00:00 2001 From: Kang Li Date: Sun, 15 Mar 2026 20:05:21 +0800 Subject: [PATCH 1/5] Add Image transform and vendor imports Introduce an Image TransformMedia implementation and wire it into the transforms registry. The new Image class accepts file paths, PIL images, numpy arrays, torch tensors, and matplotlib figures, supports optional resizing and output format validation, serializes to an image buffer, and implements column_type/build_data_record/transform (writes file with sha256-based name via safe_write). Export Image from transforms.__init__ and update tests to include an Image media factory. Also extend swanlab.vendor lazy imports to expose torch, torchvision and ensure PIL.Image submodule is imported so PIL.Image usage works with the lazy loader. --- .../sdk/internal/run/transforms/__init__.py | 3 +- .../internal/run/transforms/image/__init__.py | 163 +++++++++++++++++- swanlab/vendor/__init__.py | 19 +- .../internal/run/data/test_media_transform.py | 2 + 4 files changed, 184 insertions(+), 3 deletions(-) diff --git a/swanlab/sdk/internal/run/transforms/__init__.py b/swanlab/sdk/internal/run/transforms/__init__.py index 9c195f6db..b17f066ac 100644 --- a/swanlab/sdk/internal/run/transforms/__init__.py +++ b/swanlab/sdk/internal/run/transforms/__init__.py @@ -10,10 +10,11 @@ from swanlab.sdk.internal.context import TransformMedia from .audio import Audio +from .image import Image from .scalar import Scalar from .text import Text -__all__ = ["Text", "Scalar", "Audio", "normalize_media_input"] +__all__ = ["Text", "Scalar", "Audio", "Image", "normalize_media_input"] def normalize_media_input( diff --git a/swanlab/sdk/internal/run/transforms/image/__init__.py b/swanlab/sdk/internal/run/transforms/image/__init__.py index 1f61c7499..a4112495c 100644 --- a/swanlab/sdk/internal/run/transforms/image/__init__.py +++ b/swanlab/sdk/internal/run/transforms/image/__init__.py @@ -1,6 +1,167 @@ """ @author: cunyue @file: __init__.py -@time: 2026/3/11 19:17 +@time: 2026/3/15 @description: 图像处理模块 """ + +import hashlib +from io import BytesIO +from pathlib import Path +from typing import List, Optional, Union + +from google.protobuf.timestamp_pb2 import Timestamp + +from swanlab import vendor +from swanlab.proto.swanlab.metric.column.v1.column_pb2 import ColumnType +from swanlab.proto.swanlab.metric.data.v1.data_pb2 import DataRecord +from swanlab.proto.swanlab.metric.data.v1.media.image_pb2 import ImageItem, ImageValue +from swanlab.sdk.internal.context import TransformMedia +from swanlab.sdk.internal.pkg.fs import safe_write + +ACCEPT_FORMAT = ["png", "jpg", "jpeg", "bmp"] + + +def _is_torch_tensor(obj) -> bool: + """通过类型名检测 PyTorch Tensor,避免强制导入 torch""" + typename = obj.__class__.__module__ + "." + obj.__class__.__name__ + return typename.startswith("torch.") and ("Tensor" in typename or "Variable" in typename) + + +def _resize(image: "vendor.PIL.Image.Image", size) -> "vendor.PIL.Image.Image": + """按 size 参数缩放图像""" + if size is None: + return image + if isinstance(size, int): + if max(image.size) > size: + image.thumbnail((size, size)) + return image + if isinstance(size, (list, tuple)): + w, h = (tuple(size) + (None,))[:2] + if w is not None and h is not None: + return image.resize((int(w), int(h))) + if w is not None: + return image.resize((int(w), int(image.size[1] * w / image.size[0]))) + if h is not None: + return image.resize((int(image.size[0] * h / image.size[1]), int(h))) + raise ValueError("size must be an int, or a list/tuple with 1-2 elements") + + +class Image(TransformMedia): + def __init__( + self, + data_or_path: Union["Image", str, "vendor.PIL.Image.Image", "vendor.np.ndarray"], + mode: Optional[str] = None, + caption: Optional[str] = None, + file_type: Optional[str] = None, + size: Optional[Union[int, list, tuple]] = None, + ): + """Image class constructor + + Parameters + ---------- + data_or_path: str, PIL.Image.Image, numpy.ndarray, torch.Tensor, matplotlib figure, or Image + Path to an image file (PNG/JPG/JPEG/BMP; GIF is not supported), a PIL Image, + numpy array (shape: (H, W) or (H, W, 3/4)), torch.Tensor, matplotlib figure, + or another Image instance (nesting). + mode: str, optional + PIL mode applied when converting to PIL.Image (e.g. 'RGB', 'L'). + caption: str, optional + Caption for the image. + file_type: str, optional + Output file format. One of ['png', 'jpg', 'jpeg', 'bmp']. Defaults to 'png'. + size: int, list, or tuple, optional + Resize policy: + - int: maximum side length (aspect-ratio preserved via thumbnail). + - (w, h): exact target size. + - (w, None) / (None, h): fix one dimension, scale the other proportionally. + - None: no resize. + """ + super().__init__() + + # 套娃加载 + attrs = self._unwrap(data_or_path) + if attrs: + self.buffer: BytesIO = attrs["buffer"] + self.file_type: str = attrs["file_type"] + self.caption: Optional[str] = caption if caption is not None else attrs.get("caption") + return + + # 校验 file_type + ft = (file_type or "png").lower() + if ft not in ACCEPT_FORMAT: + raise ValueError(f"Unsupported file_type '{ft}'. Accepted: {ACCEPT_FORMAT}") + self.file_type = ft + + # ---------- 各类型输入 → PIL Image ---------- + # 考虑到懒加载限制我们一般使用鸭子类型判断,而不是 isinstance + PILImage = vendor.PIL.Image + # 1. 文件路径 + if isinstance(data_or_path, str): + if data_or_path.lower().endswith(".gif"): + raise TypeError("GIF images are not supported. Please convert to PNG or JPG first.") + try: + pil_img = PILImage.open(data_or_path) + except Exception as e: + raise ValueError(f"Failed to open image file: {data_or_path!r}") from e + if getattr(pil_img, "format", None) == "GIF": + raise TypeError("GIF images are not supported. Please convert to PNG or JPG first.") + image_data = pil_img.convert(mode) + # 2. PIL Image + elif isinstance(data_or_path, PILImage.Image): + if getattr(data_or_path, "format", None) == "GIF": + raise TypeError("GIF images are not supported. Please convert to PNG or JPG first.") + image_data = data_or_path.convert(mode) + # 3. PyTorch Tensor + elif _is_torch_tensor(data_or_path): + t = data_or_path + if hasattr(t, "requires_grad") and t.requires_grad: # type: ignore[union-attr] + t = t.detach() # type: ignore[union-attr] + t = vendor.torchvision.utils.make_grid(t, normalize=True) # type: ignore[arg-type] + image_data = PILImage.fromarray(t.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy(), mode=mode) + # 4. Matplotlib Figure + elif hasattr(data_or_path, "savefig"): + try: + buf = BytesIO() + data_or_path.savefig(buf, format=self.file_type) # type: ignore[union-attr] + buf.seek(0) + image_data = PILImage.open(buf).convert(mode) + buf.close() + except Exception as e: + raise TypeError("Failed to convert matplotlib figure to image") from e + # 5. Numpy Array + elif isinstance(data_or_path, vendor.np.ndarray): + arr = data_or_path + if arr.ndim == 2 or (arr.ndim == 3 and arr.shape[2] in (3, 4)): + image_data = PILImage.fromarray(vendor.np.clip(arr, 0, 255).astype(vendor.np.uint8), mode=mode) + else: + raise TypeError(f"Invalid numpy array shape for Image: expected (H, W) or (H, W, 3/4), got {arr.shape}") + + else: + raise TypeError( + f"Unsupported image type: {type(data_or_path).__name__}. " + "Please provide a file path, PIL.Image, numpy.ndarray, torch.Tensor, or matplotlib figure." + ) + + image_data = _resize(image_data, size) + self.buffer = BytesIO() + save_fmt = "jpeg" if self.file_type == "jpg" else self.file_type + image_data.save(self.buffer, format=save_fmt) + self.caption = caption + + @classmethod + def column_type(cls) -> ColumnType: + return ColumnType.COLUMN_TYPE_IMAGE + + @classmethod + def build_data_record(cls, *, key: str, step: int, timestamp: Timestamp, data: List[ImageItem]) -> DataRecord: + return DataRecord( + key=key, step=step, timestamp=timestamp, type=cls.column_type(), images=ImageValue(items=data) + ) + + def transform(self, *, step: int, path: Path) -> ImageItem: + content = self.buffer.getvalue() + sha256 = hashlib.sha256(content).hexdigest() + filename = f"{step:03d}-{sha256[:8]}.{self.file_type}" + safe_write(path / filename, content, mode="wb") + return ImageItem(filename=filename, sha256=sha256, size=len(content), caption=self.caption or "") diff --git a/swanlab/vendor/__init__.py b/swanlab/vendor/__init__.py index dc3c3e037..3875f0899 100644 --- a/swanlab/vendor/__init__.py +++ b/swanlab/vendor/__init__.py @@ -19,9 +19,12 @@ import numpy as np import pandas as pd import PIL + import PIL.Image import rdkit import soundfile import swanboard + import torch + import torchvision # 2. Expose the available modules for IDE auto-completion @@ -35,6 +38,8 @@ "soundfile", "swanboard", "boto3", + "torch", + "torchvision", # these are extra dependencies which are not in [project.optional-dependencies] "pd", ] @@ -50,6 +55,8 @@ "soundfile": "soundfile", "swanboard": "swanboard", "boto3": "boto3", + "torch": "torch", + "torchvision": "torchvision", # these are extra dependencies which are not in [project.optional-dependencies] "pd": "pandas", } @@ -71,8 +78,14 @@ "boto3": "s3", } +# 5. Submodule imports: some packages require submodules to be imported explicitly +# so their attributes are accessible (e.g. PIL.Image must be imported for PIL.Image to work) +_SUBMODULE_IMPORTS = { + "PIL": ["PIL.Image"], +} + -# 5. Module-level __getattr__ for lazy loading (PEP 562) +# 6. Module-level __getattr__ for lazy loading (PEP 562) def __getattr__(name: str) -> Any: if name in _LAZY_IMPORTS: module_path = _LAZY_IMPORTS[name] @@ -86,6 +99,10 @@ def __getattr__(name: str) -> Any: # Handle direct third-party library imports obj = importlib.import_module(module_path) + # Import required submodules so their attributes are accessible on the parent package + for submodule_path in _SUBMODULE_IMPORTS.get(name, []): + importlib.import_module(submodule_path) + # Cache the imported object in the module's global namespace globals()[name] = obj return obj diff --git a/tests/unit/sdk/internal/run/data/test_media_transform.py b/tests/unit/sdk/internal/run/data/test_media_transform.py index 3ff3d9b46..0cda5fa51 100644 --- a/tests/unit/sdk/internal/run/data/test_media_transform.py +++ b/tests/unit/sdk/internal/run/data/test_media_transform.py @@ -18,11 +18,13 @@ import swanlab.sdk.internal.run.transforms # noqa: F401 # type: ignore — 触发所有子类注册 from swanlab.sdk.internal.context import TransformMedia from swanlab.sdk.internal.run.transforms.audio import Audio +from swanlab.sdk.internal.run.transforms.image import Image from swanlab.sdk.internal.run.transforms.text import Text # 注册表:TransformMedia 子类 → 无参工厂(每次调用返回内容相同的新实例) MEDIA_FACTORIES = { Audio: lambda: Audio(np.zeros((1, 4410), dtype=np.float32), sample_rate=44100), + Image: lambda: Image(np.zeros((10, 10, 3), dtype=np.uint8)), Text: lambda: Text(content="hello world"), } From 5febb49b010ec0ce7fe3b09ab5664a42a847b45e Mon Sep 17 00:00:00 2001 From: Kang Li Date: Sun, 15 Mar 2026 21:06:47 +0800 Subject: [PATCH 2/5] Add Video transform and media logging Introduce a Video media transform (GIF support) with format detection, safe file writing, sha256-based filenames, and build_data_record/transform implementations. Export Video from transforms and simplify normalize_media_input to always construct media instances (enables nested/wrapped media). Add SwanLabRun convenience methods: log_image, log_audio (with sample_rate), and log_video that normalize inputs and forward to log. Include unit tests for Video behavior and parameterized media logging; update media factories in existing tests. --- swanlab/sdk/internal/run/__init__.py | 74 +++++++- .../sdk/internal/run/transforms/__init__.py | 10 +- .../internal/run/transforms/video/__init__.py | 120 +++++++++++++ .../internal/run/data/test_media_transform.py | 8 + .../unit/sdk/internal/run/data/test_video.py | 170 ++++++++++++++++++ .../sdk/internal/run/test_run_log_media.py | 145 +++++++++++++++ 6 files changed, 520 insertions(+), 7 deletions(-) create mode 100644 swanlab/sdk/internal/run/transforms/video/__init__.py create mode 100644 tests/unit/sdk/internal/run/data/test_video.py create mode 100644 tests/unit/sdk/internal/run/test_run_log_media.py diff --git a/swanlab/sdk/internal/run/__init__.py b/swanlab/sdk/internal/run/__init__.py index 1d993fb6c..d9477bc91 100644 --- a/swanlab/sdk/internal/run/__init__.py +++ b/swanlab/sdk/internal/run/__init__.py @@ -28,7 +28,7 @@ create_unbound_run_config, deactivate_run_config, ) -from swanlab.sdk.internal.run.transforms import Text, normalize_media_input +from swanlab.sdk.internal.run.transforms import Audio, Image, Text, Video, normalize_media_input from swanlab.sdk.typings.run import FinishType from swanlab.sdk.typings.run.column import ScalarXAxisType @@ -257,6 +257,78 @@ def log_text( normalized_data = normalize_media_input(Text, data, caption=caption) self.log({key: normalized_data}, step=step) + @with_lock + @with_run + def log_image( + self, + key: str, + data: Union[Image, Any, List[Any]], + caption: Optional[Union[str, List[str]]] = None, + step: Optional[int] = None, + ): + """ + A syntactic sugar for logging image data. + + :param key: The key for the image data. + + :param data: The image data itself or an Image object. + + :param caption: Optional caption for the image data. + + :param step: Optional step for the image data. + """ + normalized_data = normalize_media_input(Image, data, caption=caption) + self.log({key: normalized_data}, step=step) + + @with_lock + @with_run + def log_audio( + self, + key: str, + data: Union[Audio, Any, List[Any]], + sample_rate: int = 44100, + caption: Optional[Union[str, List[str]]] = None, + step: Optional[int] = None, + ): + """ + A syntactic sugar for logging audio data. + + :param key: The key for the audio data. + + :param data: The audio data itself or an Audio object. + + :param sample_rate: Sample rate of the audio (used when data is raw numpy array). + + :param caption: Optional caption for the audio data. + + :param step: Optional step for the audio data. + """ + normalized_data = normalize_media_input(Audio, data, caption=caption, sample_rate=sample_rate) + self.log({key: normalized_data}, step=step) + + @with_lock + @with_run + def log_video( + self, + key: str, + data: Union[Video, Any, List[Any]], + caption: Optional[Union[str, List[str]]] = None, + step: Optional[int] = None, + ): + """ + A syntactic sugar for logging video data. + + :param key: The key for the video data. + + :param data: The video data itself or a Video object. + + :param caption: Optional caption for the video data. + + :param step: Optional step for the video data. + """ + normalized_data = normalize_media_input(Video, data, caption=caption) + self.log({key: normalized_data}, step=step) + @with_lock @with_run def define_scalar( diff --git a/swanlab/sdk/internal/run/transforms/__init__.py b/swanlab/sdk/internal/run/transforms/__init__.py index b17f066ac..7f48ed784 100644 --- a/swanlab/sdk/internal/run/transforms/__init__.py +++ b/swanlab/sdk/internal/run/transforms/__init__.py @@ -13,8 +13,9 @@ from .image import Image from .scalar import Scalar from .text import Text +from .video import Video -__all__ = ["Text", "Scalar", "Audio", "Image", "normalize_media_input"] +__all__ = ["Text", "Scalar", "Audio", "Image", "Video", "normalize_media_input"] def normalize_media_input( @@ -57,10 +58,7 @@ def normalize_media_input( # 构造媒体对象列表 result = [] for i, item in enumerate(data_list): - if isinstance(item, media_cls): - result.append(item) - else: - item_kwargs = {k: v[i] for k, v in normalized_kwargs.items()} - result.append(media_cls(*[item], **item_kwargs)) + item_kwargs = {k: v[i] for k, v in normalized_kwargs.items()} + result.append(media_cls(*[item], **item_kwargs)) return result diff --git a/swanlab/sdk/internal/run/transforms/video/__init__.py b/swanlab/sdk/internal/run/transforms/video/__init__.py new file mode 100644 index 000000000..d282e8b1e --- /dev/null +++ b/swanlab/sdk/internal/run/transforms/video/__init__.py @@ -0,0 +1,120 @@ +""" +@author: cunyue +@file: __init__.py +@time: 2026/3/15 +@description: 视频处理模块,暂时只支持 GIF +""" + +import hashlib +from io import BytesIO +from pathlib import Path +from typing import List, Optional, Union + +from google.protobuf.timestamp_pb2 import Timestamp + +from swanlab.proto.swanlab.metric.column.v1.column_pb2 import ColumnType +from swanlab.proto.swanlab.metric.data.v1.data_pb2 import DataRecord +from swanlab.proto.swanlab.metric.data.v1.media.video_pb2 import VideoItem, VideoValue +from swanlab.sdk.internal.context import TransformMedia +from swanlab.sdk.internal.pkg.fs import safe_write + +# 各格式的魔数校验表,新增格式时在此追加 +# format → (magic_bytes, ...) +_FORMAT_MAGIC: dict[str, tuple[bytes, ...]] = { + "gif": (b"GIF87a", b"GIF89a"), +} + +# 路径后缀 → 格式名 +_EXT_TO_FORMAT: dict[str, str] = { + ".gif": "gif", +} + + +def _detect_format_by_magic(data: bytes) -> Optional[str]: + """根据魔数推断格式,无法识别则返回 None""" + for fmt, magics in _FORMAT_MAGIC.items(): + if any(data.startswith(m) for m in magics): + return fmt + return None + + +class Video(TransformMedia): + def __init__( + self, + data_or_path: Union["Video", str, bytes, BytesIO], + caption: Optional[str] = None, + ): + """Video class constructor + + 目前支持的格式:GIF。 + + Parameters + ---------- + data_or_path: str, bytes, BytesIO, or Video + Path to a supported video file, raw video bytes, a BytesIO containing + video data, or another Video instance (nesting). + caption: str, optional + Caption for the video. + """ + super().__init__() + + # 套娃加载 + attrs = self._unwrap(data_or_path) + if attrs: + self.buffer: BytesIO = attrs["buffer"] + self.format: str = attrs["format"] + self.caption: Optional[str] = caption if caption is not None else attrs.get("caption") + return + + # 1. 文件路径 + if isinstance(data_or_path, str): + ext = Path(data_or_path).suffix.lower() + if ext not in _EXT_TO_FORMAT: + supported = ", ".join(_EXT_TO_FORMAT) + raise TypeError(f"Unsupported file extension '{ext}'. Supported: {supported}") + try: + with open(data_or_path, "rb") as f: + raw = f.read() + except OSError as e: + raise ValueError(f"Failed to open file: {data_or_path!r}") from e + fmt = _detect_format_by_magic(raw) + if fmt is None: + raise TypeError(f"File '{data_or_path}' does not match any known video format magic number.") + self.format = fmt + + # 2. bytes 或 BytesIO + elif isinstance(data_or_path, (bytes, BytesIO)): + raw = data_or_path if isinstance(data_or_path, bytes) else data_or_path.read() + fmt = _detect_format_by_magic(raw) + if fmt is None: + supported = ", ".join(_FORMAT_MAGIC) + raise TypeError(f"Cannot detect video format from bytes. Supported formats: {supported}") + self.format = fmt + + # 3. 其他类型 + else: + supported = ", ".join(_EXT_TO_FORMAT) + raise TypeError( + f"Unsupported type: {type(data_or_path).__name__}. " + f"Please provide a file path ({supported}), bytes, or BytesIO." + ) + + self.buffer = BytesIO(raw) + self.caption = caption + + @classmethod + def column_type(cls) -> ColumnType: + return ColumnType.COLUMN_TYPE_VIDEO + + @classmethod + def build_data_record(cls, *, key: str, step: int, timestamp: Timestamp, data: List[VideoItem]) -> DataRecord: + return DataRecord( + key=key, step=step, timestamp=timestamp, type=cls.column_type(), videos=VideoValue(items=data) + ) + + def transform(self, *, step: int, path: Path) -> VideoItem: + content = self.buffer.getvalue() + sha256 = hashlib.sha256(content).hexdigest() + filename = f"{step:03d}-{sha256[:8]}.{self.format}" + safe_write(path / filename, content, mode="wb") + return VideoItem(filename=filename, sha256=sha256, size=len(content), caption=self.caption or "") diff --git a/tests/unit/sdk/internal/run/data/test_media_transform.py b/tests/unit/sdk/internal/run/data/test_media_transform.py index 0cda5fa51..ab47e1d3e 100644 --- a/tests/unit/sdk/internal/run/data/test_media_transform.py +++ b/tests/unit/sdk/internal/run/data/test_media_transform.py @@ -20,12 +20,20 @@ from swanlab.sdk.internal.run.transforms.audio import Audio from swanlab.sdk.internal.run.transforms.image import Image from swanlab.sdk.internal.run.transforms.text import Text +from swanlab.sdk.internal.run.transforms.video import Video + +# 最小合法 GIF89a(1×1 像素) +_GIF_1X1 = ( + b"GIF89a\x01\x00\x01\x00\x80\x00\x00\xff\xff\xff\x00\x00\x00" + b"!\xf9\x04\x00\x00\x00\x00\x00,\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;" +) # 注册表:TransformMedia 子类 → 无参工厂(每次调用返回内容相同的新实例) MEDIA_FACTORIES = { Audio: lambda: Audio(np.zeros((1, 4410), dtype=np.float32), sample_rate=44100), Image: lambda: Image(np.zeros((10, 10, 3), dtype=np.uint8)), Text: lambda: Text(content="hello world"), + Video: lambda: Video(_GIF_1X1), } diff --git a/tests/unit/sdk/internal/run/data/test_video.py b/tests/unit/sdk/internal/run/data/test_video.py new file mode 100644 index 000000000..34cb81df8 --- /dev/null +++ b/tests/unit/sdk/internal/run/data/test_video.py @@ -0,0 +1,170 @@ +""" +@author: cunyue +@file: test_video.py +@time: 2026/3/15 +@description: 视频处理模块单元测试 +""" + +import hashlib +from io import BytesIO +from pathlib import Path + +import pytest +from google.protobuf.timestamp_pb2 import Timestamp + +from swanlab.proto.swanlab.metric.data.v1.media.video_pb2 import VideoItem +from swanlab.sdk.internal.run.transforms.video import Video + +# 最小合法 GIF89a(1×1 像素) +GIF_BYTES = ( + b"GIF89a\x01\x00\x01\x00\x80\x00\x00\xff\xff\xff\x00\x00\x00" + b"!\xf9\x04\x00\x00\x00\x00\x00,\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;" +) + + +@pytest.fixture +def gif_file(tmp_path: Path) -> str: + """写一个临时 GIF 文件,返回其路径字符串""" + path = tmp_path / "test.gif" + path.write_bytes(GIF_BYTES) + return str(path) + + +# ---------------------------------- 构造测试 ---------------------------------- + + +class TestVideoInit: + def test_from_bytes(self): + v = Video(GIF_BYTES) + assert v.format == "gif" + assert v.buffer.getvalue() == GIF_BYTES + assert v.caption is None + + def test_from_bytesio(self): + v = Video(BytesIO(GIF_BYTES)) + assert v.format == "gif" + assert v.buffer.getvalue() == GIF_BYTES + + def test_from_path(self, gif_file): + v = Video(gif_file) + assert v.format == "gif" + assert v.buffer.getvalue() == GIF_BYTES + + def test_caption_stored(self): + v = Video(GIF_BYTES, caption="my clip") + assert v.caption == "my clip" + + def test_caption_none_by_default(self): + v = Video(GIF_BYTES) + assert v.caption is None + + +# ---------------------------------- 错误处理测试 ---------------------------------- + + +class TestVideoInitErrors: + def test_invalid_extension_raises(self, tmp_path): + p = tmp_path / "clip.mp4" + p.write_bytes(b"\x00" * 16) + with pytest.raises(TypeError, match="Unsupported file extension"): + Video(str(p)) + + def test_nonexistent_path_raises(self): + with pytest.raises(ValueError, match="Failed to open file"): + Video("/nonexistent/path/clip.gif") + + def test_bad_magic_in_file_raises(self, tmp_path): + """文件扩展名合法但内容不是 GIF""" + p = tmp_path / "fake.gif" + p.write_bytes(b"\xff\xd8\xff\xe0" + b"\x00" * 16) # JPEG 魔数 + with pytest.raises(TypeError, match="magic number"): + Video(str(p)) + + def test_bad_magic_in_bytes_raises(self): + with pytest.raises(TypeError, match="Cannot detect video format"): + Video(b"\x00\x01\x02\x03\x04\x05\x06\x07") + + def test_unsupported_type_raises(self): + with pytest.raises(TypeError, match="Unsupported type"): + Video(12345) # type: ignore + + +# ---------------------------------- 套娃加载测试 ---------------------------------- + + +class TestVideoNesting: + def test_wrap_copies_buffer_and_format(self): + inner = Video(GIF_BYTES, caption="inner") + outer = Video(inner) + assert outer.buffer is inner.buffer + assert outer.format == inner.format + assert outer.caption == "inner" + + def test_outer_caption_overrides_inner(self): + inner = Video(GIF_BYTES, caption="inner") + outer = Video(inner, caption="outer") + assert outer.caption == "outer" + + def test_inner_caption_used_when_outer_none(self): + inner = Video(GIF_BYTES, caption="inner") + outer = Video(inner, caption=None) + assert outer.caption == "inner" + + +# ---------------------------------- column_type / build_data_record 测试 ---------------------------------- + + +class TestVideoColumnType: + def test_column_type(self): + from swanlab.proto.swanlab.metric.column.v1.column_pb2 import ColumnType + + assert Video.column_type() == ColumnType.COLUMN_TYPE_VIDEO + + +class TestVideoBuildDataRecord: + def test_build_data_record_structure(self, tmp_path): + item = Video(GIF_BYTES).transform(step=1, path=tmp_path) + ts = Timestamp() + record = Video.build_data_record(key="rollout", step=1, timestamp=ts, data=[item]) + + assert record.key == "rollout" + assert record.step == 1 + assert len(record.videos.items) == 1 + assert record.videos.items[0].filename == item.filename + + def test_build_data_record_multiple_items(self, tmp_path): + i1 = Video(GIF_BYTES).transform(step=1, path=tmp_path) + i2 = Video(BytesIO(GIF_BYTES)).transform(step=1, path=tmp_path) + ts = Timestamp() + record = Video.build_data_record(key="k", step=1, timestamp=ts, data=[i1, i2]) + assert len(record.videos.items) == 2 + + +# ---------------------------------- transform 特有字段测试 ---------------------------------- + + +class TestVideoTransform: + def test_transform_returns_video_item(self, tmp_path): + item = Video(GIF_BYTES).transform(step=1, path=tmp_path) + assert isinstance(item, VideoItem) + + def test_transform_sha256_correct(self, tmp_path): + item = Video(GIF_BYTES).transform(step=1, path=tmp_path) + content = (tmp_path / item.filename).read_bytes() + assert item.sha256 == hashlib.sha256(content).hexdigest() + + def test_transform_size_correct(self, tmp_path): + item = Video(GIF_BYTES).transform(step=1, path=tmp_path) + assert item.size == len((tmp_path / item.filename).read_bytes()) + + def test_transform_caption_empty_when_none(self, tmp_path): + item = Video(GIF_BYTES).transform(step=1, path=tmp_path) + assert item.caption == "" + + def test_transform_caption_preserved(self, tmp_path): + item = Video(GIF_BYTES, caption="hello").transform(step=1, path=tmp_path) + assert item.caption == "hello" + + def test_transform_format_in_filename(self, tmp_path): + item = Video(GIF_BYTES).transform(step=3, path=tmp_path) + assert item.filename.endswith(".gif") diff --git a/tests/unit/sdk/internal/run/test_run_log_media.py b/tests/unit/sdk/internal/run/test_run_log_media.py new file mode 100644 index 000000000..b1a8451a8 --- /dev/null +++ b/tests/unit/sdk/internal/run/test_run_log_media.py @@ -0,0 +1,145 @@ +""" +@author: cunyue +@file: test_run_log_media.py +@time: 2026/3/15 +@description: SwanLabRun.log_text / log_image / log_audio / log_video 的参数化测试 +""" + +import threading +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from swanlab.sdk.internal.run import SwanLabRun +from swanlab.sdk.internal.run.transforms.audio import Audio +from swanlab.sdk.internal.run.transforms.image import Image +from swanlab.sdk.internal.run.transforms.text import Text +from swanlab.sdk.internal.run.transforms.video import Video + +# 最小合法 GIF89a(1×1 像素) +_GIF_1X1 = ( + b"GIF89a\x01\x00\x01\x00\x80\x00\x00\xff\xff\xff\x00\x00\x00" + b"!\xf9\x04\x00\x00\x00\x00\x00,\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;" +) + + +def _make_text() -> Text: + return Text(content="hello world") + + +def _make_audio() -> Audio: + return Audio(np.zeros((1, 4410), dtype=np.float32), sample_rate=44100) + + +def _make_image() -> Image: + return Image(np.zeros((10, 10, 3), dtype=np.uint8)) + + +def _make_video() -> Video: + return Video(_GIF_1X1) + + +# ────────────────────────────────────────────── +# Mock run +# ────────────────────────────────────────────── + + +class _MockRun: + """最小化的 SwanLabRun 替身,供非绑定方法测试使用""" + + def __init__(self): + self._state = "running" + self._api_lock = threading.RLock() + self.log = MagicMock() + + +# ────────────────────────────────────────────── +# 参数化用例 +# ────────────────────────────────────────────── + +# (method_name, media_factory, extra_kwargs_for_method) +_MEDIA_LOG_CASES = [ + pytest.param("log_text", _make_text, {}, id="log_text"), + pytest.param("log_image", _make_image, {}, id="log_image"), + pytest.param("log_audio", _make_audio, {"sample_rate": 44100}, id="log_audio"), + pytest.param("log_video", _make_video, {}, id="log_video"), +] + + +def _call_step(method_name, mock_run, key, data, extra, caption=None, step=None): + method = getattr(SwanLabRun, method_name) + return method(mock_run, key, data, **extra, caption=caption, step=step) + + +@pytest.mark.parametrize("method_name,factory,extra", _MEDIA_LOG_CASES) +class TestRunLogMedia: + """log_text / log_image / log_audio / log_video 的通用契约测试 + 没有log_echarts和log_object3D,因为他们的操作比较复杂 + """ + + def test_single_instance_wrapped_in_list(self, method_name, factory, extra): + """传入单个媒体实例 → log 被调用,data 会被包装""" + mock_run = _MockRun() + instance = factory() + _call_step(method_name, mock_run, "key", instance, extra) + mock_run.log.assert_called_once() + log_data = mock_run.log.call_args[0][0] + assert "key" in log_data + assert len(log_data["key"]) == 1 + assert log_data["key"] != [instance] + assert type(log_data["key"][0]) is instance.__class__ + + def test_list_of_instances_passed_through(self, method_name, factory, extra): + """传入实例列表 → log 被调用,data 保持列表不变""" + mock_run = _MockRun() + instances = [factory(), factory()] + _call_step(method_name, mock_run, "metrics/a", instances, extra) + mock_run.log.assert_called_once() + log_data = mock_run.log.call_args[0][0] + assert log_data["metrics/a"] != instances + assert all(isinstance(item, factory().__class__) for item in log_data["metrics/a"]) + + def test_step_forwarded(self, method_name, factory, extra): + """step 参数应被原样透传给 log""" + mock_run = _MockRun() + _call_step(method_name, mock_run, "k", factory(), extra, step=7) + _, kwargs = mock_run.log.call_args + assert kwargs.get("step") == 7 + + def test_step_none_by_default(self, method_name, factory, extra): + """不传 step 时,log 收到 step=None""" + mock_run = _MockRun() + _call_step(method_name, mock_run, "k", factory(), extra) + _, kwargs = mock_run.log.call_args + assert kwargs.get("step") is None + + def test_raises_when_not_running(self, method_name, factory, extra): + """run 未激活时应抛出 RuntimeError""" + mock_run = _MockRun() + mock_run._state = "finished" + with pytest.raises(RuntimeError, match="requires an active SwanLabRun"): + _call_step(method_name, mock_run, "k", factory(), extra) + + +def _call_caption(method_name, mock_run, key, data, extra, caption=None, step=None): + method = getattr(SwanLabRun, method_name) + return method(mock_run, key, data, **extra, caption=caption, step=step) + + +@pytest.mark.parametrize("method_name,factory,extra", _MEDIA_LOG_CASES) +class TestRunLogMediaCaption: + """caption 参数行为(对所有媒体类型一致)""" + + def test_caption_applied_to_raw_data(self, method_name, factory, extra): + """传入原始数据 + caption → normalize 后的对象 caption 正确""" + mock_run = _MockRun() + instance = factory() + _call_caption(method_name, mock_run, "k", instance, extra, caption="my caption") + # We can't inspect the caption easily without calling transform, so just + # verify log was called with a list + mock_run.log.assert_called_once() + log_data = mock_run.log.call_args[0][0] + assert isinstance(log_data["k"], list) + assert len(log_data["k"]) == 1 + assert log_data["k"][0].caption == "my caption" From 3d44e8e0cc1f5a50e2969bde51c5df86f558619a Mon Sep 17 00:00:00 2001 From: Kang Li Date: Sun, 15 Mar 2026 22:44:38 +0800 Subject: [PATCH 3/5] Refactor run commands; add media log APIs Consolidate top-level run-related commands into a single cmd/run.py that generates wrappers (_make_run_cmd) for SwanLabRun methods (finish, log, log_scalar, log_text, log_image, log_audio, log_video). Remove legacy cmd/finish.py and cmd/log.py and update sdk/__init__.py and top-level swanlab/__init__.py to export the new logging APIs. Add a comprehensive type-stub swanlab/__init__.pyi describing the public API. Improve SwanLabRun lifecycle handling: register atexit and sys.excepthook handlers, add private _atexit_cleanup, _excepthook and _cleanup to centralize teardown and resource reset. Update exceptions export list and remove the example main.py. Tests updated: remove old cmd tests, rename some internal run tests, and add unit tests for the finish/exception hooks to validate the new lifecycle behavior. --- swanlab/__init__.py | 16 +- swanlab/__init__.pyi | 409 ++++++++++++++++++ swanlab/exceptions.py | 2 + swanlab/main.py | 22 - swanlab/sdk/__init__.py | 7 +- swanlab/sdk/cmd/finish.py | 107 ----- swanlab/sdk/cmd/log.py | 92 ---- swanlab/sdk/cmd/run.py | 43 ++ swanlab/sdk/internal/run/__init__.py | 87 +++- tests/unit/sdk/cmd/finish/test_finish.py | 133 ------ tests/unit/sdk/cmd/finish/test_finish_e2e.py | 2 +- tests/unit/sdk/cmd/test_log.py | 64 --- .../unit/sdk/internal/run/test_finish_hook.py | 90 ++++ ...est_run_log_media.py => test_log_media.py} | 0 ...t_decorators.py => test_run_decorators.py} | 0 15 files changed, 629 insertions(+), 445 deletions(-) create mode 100644 swanlab/__init__.pyi delete mode 100644 swanlab/main.py delete mode 100644 swanlab/sdk/cmd/finish.py delete mode 100644 swanlab/sdk/cmd/log.py create mode 100644 swanlab/sdk/cmd/run.py delete mode 100644 tests/unit/sdk/cmd/finish/test_finish.py delete mode 100644 tests/unit/sdk/cmd/test_log.py create mode 100644 tests/unit/sdk/internal/run/test_finish_hook.py rename tests/unit/sdk/internal/run/{test_run_log_media.py => test_log_media.py} (100%) rename tests/unit/sdk/internal/run/{test_decorators.py => test_run_decorators.py} (100%) diff --git a/swanlab/__init__.py b/swanlab/__init__.py index 25555ca80..bb8202a6c 100644 --- a/swanlab/__init__.py +++ b/swanlab/__init__.py @@ -2,16 +2,20 @@ Settings, SwanLabRun, config, + define_scalar, finish, get_run, has_run, init, log, + log_audio, + log_image, log_text, + log_video, login, merge_settings, ) -from swanlab.sdk.internal.run.transforms import Audio, Text +from swanlab.sdk.internal.run.transforms import Audio, Image, Text, Video from swanlab.sdk.utils.version import get_swanlab_version from . import utils @@ -25,15 +29,21 @@ "login", "log", "log_text", + "log_image", + "log_audio", + "log_video", + "define_scalar", # run "SwanLabRun", "has_run", "get_run", + # config + "config", # data "Text", "Audio", - # config - "config", + "Image", + "Video", # utils "utils", ] diff --git a/swanlab/__init__.pyi b/swanlab/__init__.pyi new file mode 100644 index 000000000..dfaf08fec --- /dev/null +++ b/swanlab/__init__.pyi @@ -0,0 +1,409 @@ +""" +swanlab public API type stubs. + +This file is the single source of truth for all type signatures exposed at the +top-level `swanlab` namespace. Add new public symbols here when they are added +to swanlab/__init__.py. +""" + +from typing import Any, List, Mapping, Optional, Union + +from . import utils +from .sdk import config +from .sdk.cmd.init import ConfigLike +from .sdk.internal.run import SwanLabRun +from .sdk.internal.run.transforms import Audio, Image, Text, Video +from .sdk.internal.settings import Settings +from .sdk.typings.run import FinishType, ModeType, ResumeType +from .sdk.typings.run.column import ScalarXAxisType +from .sdk.utils.callbacker import SwanLabCallback + +__version__: str + +__all__ = [ + # cmd + "merge_settings", + "Settings", + "init", + "finish", + "login", + "log", + "log_text", + "log_image", + "log_audio", + "log_video", + "define_scalar", + # run + "SwanLabRun", + "has_run", + "get_run", + # data + "Text", + "Audio", + "Image", + "Video", + # config + "config", + # utils + "utils", +] + +# ── lifecycle ────────────────────────────────────────────────────────────────── + +def init( + *, + reinit: Optional[bool] = None, + logdir: Optional[str] = None, + mode: Optional[ModeType] = None, + workspace: Optional[str] = None, + project: Optional[str] = None, + public: Optional[bool] = None, + name: Optional[str] = None, + color: Optional[str] = None, + description: Optional[str] = None, + job_type: Optional[str] = None, + group: Optional[str] = None, + tags: Optional[List[str]] = None, + id: Optional[str] = None, + resume: Optional[Union[ResumeType, bool]] = None, + config: Optional[ConfigLike] = None, + settings: Optional[Settings] = None, + callbacks: Optional[List[SwanLabCallback]] = None, + **kwargs: Any, +) -> SwanLabRun: + """Initialize a new SwanLab run to track experiments. + + This function starts a new run for logging metrics, artifacts, and metadata. + After calling this, use `swanlab.log()` to log data and `swanlab.finish()` to + close the run. SwanLab automatically finishes runs at program exit. + + :param reinit: If True, finish the current run before starting a new one. Defaults to False. + :param logdir: Directory to store logs. Defaults to "./swanlog". + :param mode: Run mode. Options: "cloud" (sync to cloud), "local" (local only), + "offline" (save locally for later sync), "disabled" (no logging). Defaults to "cloud". + :param workspace: Workspace or organization name. Defaults to current user. + :param project: Project name. Defaults to current directory name. + :param public: Make project publicly visible (cloud mode only). Defaults to False. + :param name: Experiment name. Auto-generated if not provided. + :param color: Experiment color for visualization. Auto-generated if not provided. + :param description: Experiment description. + :param job_type: Job type label (e.g., "train", "eval"). + :param group: Group name for organizing related experiments. + :param tags: List of tags for categorizing experiments. + :param id: Run ID for resuming a previous run (cloud mode only). + :param resume: Resume behavior. Options: "must" (must resume), "allow" (resume if exists), + "never" (always create new). Defaults to "never". + :param config: Experiment configuration dict or path to config file (JSON/YAML). + :param settings: Custom Settings object for advanced configuration. + :param callbacks: List of callback functions triggered on run events. + :return: The initialized SwanLabRun object. + :raises RuntimeError: If a run is already active and reinit=False. + + Examples: + + Basic local run: + + >>> import swanlab + >>> swanlab.init(mode="local", project="my_project") + >>> swanlab.log({"loss": 0.5}) + >>> swanlab.finish() + + Cloud run with configuration: + + >>> import swanlab + >>> swanlab.login(api_key="your_key") + >>> swanlab.init( + ... mode="cloud", + ... project="image_classification", + ... name="resnet50_experiment", + ... config={"lr": 0.001, "batch_size": 32} + ... ) + >>> swanlab.log({"accuracy": 0.95}) + >>> swanlab.finish() + """ + ... + +def finish(state: FinishType = "success", error: Optional[str] = None) -> None: + """Finish the current run and close the experiment. + + This function safely closes the current run and waits for all logs to be flushed. + SwanLab automatically calls this function at program exit, but you can call it + manually to mark the experiment as completed with a specific state. + + :param state: Final state of the run. Must be one of: "success", "crashed", "aborted". + Defaults to "success". + :param error: Error message if state is "crashed". Required when state="crashed". + :raises RuntimeError: If called without an active run. + + Examples: + + Finish a successful run: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> swanlab.log({"loss": 0.5}) + >>> swanlab.finish() + + Mark run as crashed with error message: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> try: + ... raise ValueError("Training failed") + ... except Exception as e: + ... swanlab.finish(state="crashed", error=str(e)) + """ + ... + +def login( + api_key: Optional[str] = None, + relogin: bool = False, + host: Optional[str] = None, + save: bool = False, + timeout: int = 10, +) -> bool: + """Authenticate with SwanLab Cloud. + + This function authenticates your environment with SwanLab. If already logged in + and `relogin` is False, this function does nothing. Call this before `swanlab.init()` + to use cloud features. + + :param api_key: Your SwanLab API key. If not provided, will attempt to read from + environment or prompt for input. + :param relogin: If True, forces re-authentication and overwrites existing credentials. + Defaults to False. + :param host: Custom API host URL. If not provided, uses the default SwanLab cloud host. + :param save: Whether to save the API key locally for future sessions. Defaults to False. + :param timeout: Network request timeout in seconds. Defaults to 10. + :return: True if login was successful, False otherwise. + :raises RuntimeError: If called while a run is active. + :raises AuthenticationError: If login fails due to invalid credentials or network issues. + + Examples: + + Login with an API key: + + >>> import swanlab + >>> swanlab.login(api_key="your_api_key_here") + >>> swanlab.init(mode="cloud") + + Force re-login and save credentials: + + >>> import swanlab + >>> swanlab.login(api_key="new_api_key", relogin=True, save=True) + """ + ... + +def merge_settings(settings: Union[Settings, dict]) -> None: + """Merge custom settings into the global SwanLab configuration. + + This function allows you to customize SwanLab's behavior before initializing a run. + It must be called before `swanlab.init()`. + + :param settings: Custom settings to merge. Can be either a Settings object or a dict. + :raises RuntimeError: If called while a run is active. + + Examples: + + >>> import swanlab + >>> swanlab.merge_settings({"mode": "local", "logdir": "./my_logs"}) + >>> swanlab.init() + """ + ... + +# ── run access ───────────────────────────────────────────────────────────────── + +def has_run() -> bool: + """Check if there is an active SwanLab run. + + :return: True if a run is currently active, False otherwise. + + Examples: + + >>> import swanlab + >>> if swanlab.has_run(): + ... swanlab.log({"metric": 1.0}) + ... else: + ... print("No active run") + """ + ... + +def get_run() -> SwanLabRun: + """Get the current active SwanLab run. + + :return: The active SwanLabRun instance. + :raises RuntimeError: If no run is currently active. + + Examples: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> run = swanlab.get_run() + >>> print(run.id) + >>> swanlab.finish() + """ + ... + +# ── logging ──────────────────────────────────────────────────────────────────── + +def log(data: Mapping[str, Any], step: Optional[int] = None) -> None: + """Log metrics and data to the current run. + + :param data: Dictionary of metric names and values to log. + :param step: Optional step number. If not provided, auto-increments. + :raises RuntimeError: If called without an active run. + + Examples: + + Log multiple metrics: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> swanlab.log({"loss": 0.5, "accuracy": 0.95}) + >>> swanlab.finish() + + Log with explicit step: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> swanlab.log({"loss": 0.5}, step=10) + >>> swanlab.finish() + """ + ... + +def log_text( + key: str, + data: Union[str, Text, List[str], List[Text]], + caption: Optional[Union[str, List[str]]] = None, + step: Optional[int] = None, +) -> None: + """A syntactic sugar for logging text data. + + :param key: The key for the text data. + :param data: The text data itself or a Text object. + :param caption: Optional caption for the text data. + :param step: Optional step number. If not provided, auto-increments. + :raises RuntimeError: If called without an active run. + + Examples: + + Log simple text: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> swanlab.log_text("output", "Training started") + >>> swanlab.finish() + + Log text with caption: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> swanlab.log_text("prediction", "cat", caption="Model output") + >>> swanlab.finish() + """ + ... + +def log_image( + key: str, + data: Union[Image, Any, List[Any]], + caption: Optional[Union[str, List[str]]] = None, + step: Optional[int] = None, +) -> None: + """A syntactic sugar for logging image data. + + :param key: The key for the image data. + :param data: The image data itself or an Image object. + :param caption: Optional caption for the image data. + :param step: Optional step number. If not provided, auto-increments. + :raises RuntimeError: If called without an active run. + + Examples: + + >>> import swanlab, numpy as np + >>> swanlab.init(mode="local") + >>> img = np.zeros((64, 64, 3), dtype=np.uint8) + >>> swanlab.log_image("sample", img, caption="blank image") + >>> swanlab.finish() + """ + ... + +def log_audio( + key: str, + data: Union[Audio, Any, List[Any]], + sample_rate: int = 44100, + caption: Optional[Union[str, List[str]]] = None, + step: Optional[int] = None, +) -> None: + """A syntactic sugar for logging audio data. + + :param key: The key for the audio data. + :param data: The audio data itself or an Audio object. + :param sample_rate: Sample rate of the audio (used when data is raw numpy array). + :param caption: Optional caption for the audio data. + :param step: Optional step number. If not provided, auto-increments. + :raises RuntimeError: If called without an active run. + + Examples: + + >>> import swanlab, numpy as np + >>> swanlab.init(mode="local") + >>> audio = np.zeros((1, 44100), dtype=np.float32) + >>> swanlab.log_audio("sound", audio, sample_rate=44100) + >>> swanlab.finish() + """ + ... + +def log_video( + key: str, + data: Union[Video, Any, List[Any]], + caption: Optional[Union[str, List[str]]] = None, + step: Optional[int] = None, +) -> None: + """A syntactic sugar for logging video data. + + Currently supported formats: GIF. + + :param key: The key for the video data. + :param data: The video data itself or a Video object. + :param caption: Optional caption for the video data. + :param step: Optional step number. If not provided, auto-increments. + :raises RuntimeError: If called without an active run. + + Examples: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> with open("animation.gif", "rb") as f: + ... swanlab.log_video("rollout", f.read()) + >>> swanlab.finish() + """ + ... + +def define_scalar( + key: str, + name: Optional[str] = None, + color: Optional[str] = None, + x_axis: Optional[ScalarXAxisType] = None, + chart_name: Optional[str] = None, +) -> None: + """Explicitly define a scalar column. + + Call this before logging to customize how a scalar metric is displayed, + such as setting a display name, color, or x-axis type. + + :param key: The key for the scalar column. Supports glob patterns (e.g. "train/*") to match multiple columns at once. + :param name: Optional display name for the scalar column. + :param color: Optional hex color for the scalar line in charts. + :param x_axis: Optional x-axis type. One of "_step", "_relative_time", or a custom key. + :param chart_name: Optional name for the chart group this column belongs to. + :raises RuntimeError: If called without an active run. + + Examples: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> swanlab.define_scalar("loss", color="#FF5733", x_axis="_step") + >>> swanlab.log({"loss": 0.5}) + >>> swanlab.finish() + """ + ... diff --git a/swanlab/exceptions.py b/swanlab/exceptions.py index 525b2fd98..1ac3216ef 100644 --- a/swanlab/exceptions.py +++ b/swanlab/exceptions.py @@ -9,6 +9,8 @@ from requests.exceptions import HTTPError +__all__ = ["ApiError", "AuthenticationError", "DataStoreError"] + class ApiError(HTTPError): """ diff --git a/swanlab/main.py b/swanlab/main.py deleted file mode 100644 index 04e774a6e..000000000 --- a/swanlab/main.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -@author: cunyue -@file: main.py -@time: 2026/3/5 13:11 -@description: hello world -""" - -from typing import TypedDict - - -def main(): - t: Test = {"a": 2, "b": 2} - print("hello world", t["a"], t["b"]) - - -class Test(TypedDict): - a: int - b: int - - -if __name__ == "__main__": - main() diff --git a/swanlab/sdk/__init__.py b/swanlab/sdk/__init__.py index b966226aa..570350141 100644 --- a/swanlab/sdk/__init__.py +++ b/swanlab/sdk/__init__.py @@ -5,11 +5,10 @@ @description: SwanLab SDK,负责SwanLab库的核心指标上传功能 """ -from swanlab.sdk.cmd.finish import finish from swanlab.sdk.cmd.init import init -from swanlab.sdk.cmd.log import log, log_text from swanlab.sdk.cmd.login import login from swanlab.sdk.cmd.merge_settings import Settings, merge_settings +from swanlab.sdk.cmd.run import define_scalar, finish, log, log_audio, log_image, log_text, log_video from swanlab.sdk.internal.run import SwanLabRun, clear_run, get_run, has_run, set_run from swanlab.sdk.internal.run.config import config @@ -21,6 +20,10 @@ "login", "log", "log_text", + "log_image", + "log_audio", + "log_video", + "define_scalar", "SwanLabRun", "has_run", "get_run", diff --git a/swanlab/sdk/cmd/finish.py b/swanlab/sdk/cmd/finish.py deleted file mode 100644 index 147cbd276..000000000 --- a/swanlab/sdk/cmd/finish.py +++ /dev/null @@ -1,107 +0,0 @@ -""" -@author: cunyue -@file: finish.py -@time: 2026/3/6 21:49 -@description: SwanLab SDK 结束当前运行 -""" - -import atexit -import sys -import traceback -from types import TracebackType -from typing import Optional, Type - -from swanlab.sdk.cmd.helper import with_cmd_lock, with_run -from swanlab.sdk.internal.pkg import console -from swanlab.sdk.internal.run import get_run, has_run -from swanlab.sdk.typings.run import FinishType - - -@with_cmd_lock -@with_run("finish") -def finish(state: FinishType = "success", error: Optional[str] = None): - """Finish the current run and close the experiment. - - This function safely closes the current run and waits for all logs to be flushed. - SwanLab automatically calls this function at program exit, but you can call it - manually to mark the experiment as completed with a specific state. - - :param state: Final state of the run. Must be one of: "success", "crashed", "aborted". - Defaults to "success". - - :param error: Error message if state is "crashed". Required when state="crashed". - - :raises RuntimeError: If called without an active run. - - Examples: - - Finish a successful run: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> swanlab.log({"loss": 0.5}) - >>> swanlab.finish() - - Mark run as crashed with error message: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> try: - ... # training code - ... raise ValueError("Training failed") - ... except Exception as e: - ... swanlab.finish(state="crashed", error=str(e)) - - Mark run as aborted: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> swanlab.log({"loss": 0.5}) - >>> swanlab.finish(state="aborted") - """ - run = get_run() - run.finish(state, error) - - -def atexit_finish(): - """ - 全局退出时自动结束当前运行,此时代码正常执行完毕 - """ - if not has_run(): - return - console.debug("SwanLab Run is finishing at exit...") - run = get_run() - run.finish() - - -atexit.register(atexit_finish) - - -def swanlab_excepthook(tp: Type[BaseException], val: BaseException, tb: Optional[TracebackType]): - """全局异常捕获,用于将实验标记为 crashed""" - try: - if not has_run(): - return - state: FinishType = "crashed" - if tp is KeyboardInterrupt: - console.info("KeyboardInterrupt by user") - state = "aborted" - else: - console.info("Error happened while training") - # 生成错误堆栈 - full_error_msg = "".join(traceback.format_exception(tp, val, tb)) - - # 打印错误堆栈 - run = get_run() - run.finish(state=state, error=full_error_msg) - - except Exception as e: - console.error(f"SwanLab failed to handle excepthook: {e}", file=sys.stderr) - finally: - # _original_excepthook = sys.excepthook - # 不要用动态保存的 _original_excepthook,直接调用 Python 底层 C 实现的 sys.__excepthook__ - # 确保异常信息被正确打印 - sys.__excepthook__(tp, val, tb) - - -sys.excepthook = swanlab_excepthook diff --git a/swanlab/sdk/cmd/log.py b/swanlab/sdk/cmd/log.py deleted file mode 100644 index 4abc99715..000000000 --- a/swanlab/sdk/cmd/log.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -@author: cunyue -@file: log.py -@time: 2026/3/14 -@description: SwanLab SDK logging methods -""" - -from typing import Any, Mapping, Optional, Union - -from swanlab.sdk.cmd.helper import with_cmd_lock, with_run -from swanlab.sdk.internal.run import get_run -from swanlab.sdk.internal.run.transforms import Text - - -@with_cmd_lock -@with_run("log") -def log(data: Mapping[str, Any], step: Optional[int] = None): - """Log metrics and data to the current run. - - :param data: Dictionary of metric names and values to log. - - :param step: Optional step number. If not provided, auto-increments. - - :raises RuntimeError: If called without an active run. - - Examples: - - Log single metric: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> swanlab.log({"loss": 0.5}) - >>> swanlab.finish() - - Log multiple metrics: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> swanlab.log({"loss": 0.5, "accuracy": 0.95}) - >>> swanlab.finish() - - Log with explicit step: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> swanlab.log({"loss": 0.5}, step=10) - >>> swanlab.finish() - """ - run = get_run() - run.log(data, step) - - -@with_cmd_lock -@with_run("log_text") -def log_text(key: str, data: Union[str, Text], caption: Optional[str] = None, step: Optional[int] = None): - """Log text data to the current run. - - :param key: The key for the text data. - - :param data: The text data itself or a Text object. - - :param caption: Optional caption for the text data. - - :param step: Optional step number. If not provided, auto-increments. - - :raises RuntimeError: If called without an active run. - - Examples: - - Log simple text: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> swanlab.log_text("output", "Training started") - >>> swanlab.finish() - - Log text with caption: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> swanlab.log_text("prediction", "cat", caption="Model output") - >>> swanlab.finish() - - Log text with step: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> swanlab.log_text("status", "Epoch 5 complete", step=5) - >>> swanlab.finish() - """ - run = get_run() - run.log_text(key, data, caption, step) diff --git a/swanlab/sdk/cmd/run.py b/swanlab/sdk/cmd/run.py new file mode 100644 index 000000000..f1b30955e --- /dev/null +++ b/swanlab/sdk/cmd/run.py @@ -0,0 +1,43 @@ +""" +@author: cunyue +@file: run.py +@time: 2026/3/14 +@description: SwanLab SDK run methods + +_make_run_cmd 将 SwanLabRun 上的方法包装为顶层 cmd 函数: + - __wrapped__ 指向 SwanLabRun 原方法,help() / inspect 可追溯 + - __doc__ 复用原方法文档,无需重复维护 + +新增公开方法时,在 SwanLabRun 上实现后,在此文件末尾追加一行: + xxx = _make_run_cmd("xxx") +并同步更新 swanlab/__init__.pyi 中的函数声明。 +""" + +from typing import Any, Callable + +from swanlab.sdk.cmd.helper import with_cmd_lock, with_run +from swanlab.sdk.internal.run import SwanLabRun, get_run + + +def _make_run_cmd(method_name: str) -> Callable: + run_method = getattr(SwanLabRun, method_name) + + @with_cmd_lock + @with_run(method_name) + def wrapper(*args: Any, **kwargs: Any) -> Any: + return getattr(get_run(), method_name)(*args, **kwargs) + + wrapper.__name__ = method_name + wrapper.__wrapped__ = run_method # type: ignore[attr-defined] + wrapper.__doc__ = run_method.__doc__ + return wrapper + + +# ── 每新增一个公开方法,在此追加一行 ────────────────────────────────────────── +log = _make_run_cmd("log") +log_text = _make_run_cmd("log_text") +log_image = _make_run_cmd("log_image") +log_audio = _make_run_cmd("log_audio") +log_video = _make_run_cmd("log_video") +define_scalar = _make_run_cmd("define_scalar") +finish = _make_run_cmd("finish") diff --git a/swanlab/sdk/internal/run/__init__.py b/swanlab/sdk/internal/run/__init__.py index d9477bc91..57907a534 100644 --- a/swanlab/sdk/internal/run/__init__.py +++ b/swanlab/sdk/internal/run/__init__.py @@ -8,10 +8,14 @@ 3. 触发异步微批处理落盘与回调 """ +import atexit +import sys import threading +import traceback from functools import cached_property, wraps from pathlib import Path -from typing import Any, List, Literal, Mapping, Optional, Union, cast, get_args +from types import TracebackType +from typing import Any, List, Literal, Mapping, Optional, Type, Union, cast, get_args from google.protobuf.timestamp_pb2 import Timestamp @@ -137,15 +141,64 @@ def __init__(self, ctx: RunContext): # 设置全局运行实例 set_run(self) - - # 绑定日志文件 + # 注册退出钩子 + self._sys_origin_excepthook = sys.excepthook + atexit.register(self._atexit_cleanup) + sys.excepthook = self._excepthook + # 绑定日志文件,运行正式开始 if self._ctx.config.settings.mode != "disabled": log.bindfile(self._ctx.debug_dir) # ---------------------------------- - # 属性 (Properties) + # 私有钩子 # ---------------------------------- + def _atexit_cleanup(self) -> None: + """程序正常退出时自动结束当前运行""" + if self._state != "running": + return + console.debug("SwanLab Run is finishing at exit...") + self.finish() + + def _excepthook( + self, + tp: Type[BaseException], + val: BaseException, + tb: Optional[TracebackType], + ) -> None: + """全局异常捕获,将实验标记为 crashed 或 aborted""" + try: + if self._state != "running": + return + state: FinishType = "crashed" + if tp is KeyboardInterrupt: + console.info("KeyboardInterrupt by user") + state = "aborted" + else: + console.info("Error happened while training") + full_error_msg = "".join(traceback.format_exception(tp, val, tb)) + self.finish(state=state, error=full_error_msg) + except Exception as e: + console.error(f"SwanLab failed to handle excepthook: {e}") + finally: + sys.__excepthook__(tp, val, tb) + + def _cleanup(self): + """ + 清除副作用 + """ + # 取消钩子 + console.debug("Cleanup system hook...") + atexit.unregister(self._atexit_cleanup) + sys.excepthook = self._sys_origin_excepthook + # 清理全局运行实例 + console.debug("Cleanup global instance...") + clear_run() + deactivate_run_config() + console.debug("Clean & tidy! ( •̀ ω •́ )y") + # 释放日志,本次运行结束 + log.reset() + @cached_property def id(self) -> str: """ @@ -340,18 +393,14 @@ def define_scalar( chart_name: Optional[str] = None, ): """ - Explicitly define a scalar column. - - :param key: The key for the scalar column. - - :param name: Optional name for the scalar column. - - :param color: Optional color for the scalar column. - - :param x_axis: Optional x-axis for the scalar column. - - :param chart_name: Optional name for the chart. + 手动定义一个标量列 + :param key: 标量列的键,支持通配符(如 "train/*")以匹配多个列 + :param name: 标量列的可选显示名称 + :param color: 标量列的可选颜色 + :param x_axis: 标量列的可选 x 轴类型 + :param chart_name: 标量列所属的可选图表名称 """ + # TODO: 实现 glob 匹配逻辑 if not (this_key := fmt.safe_validate_key(key)): return console.error( f"Invalid key for define scalar: {key}, please use valid characters (alphanumeric, '.', '-', '/') and avoid special characters." @@ -409,12 +458,8 @@ def finish(self, state: FinishType = "success", error: Optional[str] = None): self._emitter.emit(RunFinishEvent(state=this_state, error=error, timestamp=ts)) # 阻塞主线程,等待后台队列消费完毕 self._consumer.join() - # 清理全局运行实例 - clear_run() - console.debug(f"Run finished with state: {state}") - # 释放一些资源 - log.reset() - deactivate_run_config() + console.debug(f"SwanLab Run has finished with state: {self._state}, cleanup...") + self._cleanup() _current_run: Optional[SwanLabRun] = None diff --git a/tests/unit/sdk/cmd/finish/test_finish.py b/tests/unit/sdk/cmd/finish/test_finish.py deleted file mode 100644 index e101c57ed..000000000 --- a/tests/unit/sdk/cmd/finish/test_finish.py +++ /dev/null @@ -1,133 +0,0 @@ -""" -@author: cunyue -@file: test_finish.py -@time: 2026/3/14 -@description: 测试 swanlab.sdk.cmd.finish 中各函数的单元行为(均 mock 依赖,不启动真实 Run) -""" - -import sys -from unittest.mock import ANY, MagicMock - -from swanlab.sdk.cmd.finish import atexit_finish, finish, swanlab_excepthook - - -def _make_exc_info(exc: BaseException): - """辅助:构造 (tp, val, tb) 三元组,用于测试 excepthook""" - try: - raise exc - except BaseException: - tp, val, tb = sys.exc_info() - assert tp is not None and val is not None - return tp, val, tb - - -class TestFinishFunction: - def test_finish_calls_run_finish(self, monkeypatch): - """应将 state 和 error 透传给 run.finish()""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.helper.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.finish.get_run", lambda: mock_run) - - finish(state="crashed", error="something went wrong") - - mock_run.finish.assert_called_once_with("crashed", "something went wrong") - - def test_finish_default_state_is_success(self, monkeypatch): - """finish() 不传参时,默认 state 为 'success',error 为 None""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.helper.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.finish.get_run", lambda: mock_run) - - finish() - - mock_run.finish.assert_called_once_with("success", None) - - -class TestAtexitFinish: - def test_atexit_finish_no_run(self, monkeypatch): - """无活跃 Run 时,atexit_finish 直接返回,不调用 get_run""" - monkeypatch.setattr("swanlab.sdk.cmd.finish.has_run", lambda: False) - mock_get_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.finish.get_run", mock_get_run) - - atexit_finish() - - mock_get_run.assert_not_called() - - def test_atexit_finish_calls_run_finish(self, monkeypatch): - """有活跃 Run 时,atexit_finish 应调用 run.finish()""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.finish.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.finish.get_run", lambda: mock_run) - - atexit_finish() - - mock_run.finish.assert_called_once() - - -class TestSwanlabExcepthook: - def test_excepthook_keyboard_interrupt(self, monkeypatch): - """KeyboardInterrupt 异常 → run.finish(state='aborted', ...)""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.finish.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.finish.get_run", lambda: mock_run) - monkeypatch.setattr(sys, "__excepthook__", MagicMock()) - - tp, val, tb = _make_exc_info(KeyboardInterrupt()) - swanlab_excepthook(tp, val, tb) - - mock_run.finish.assert_called_once_with(state="aborted", error=ANY) - - def test_excepthook_generic_exception(self, monkeypatch): - """普通异常 → run.finish(state='crashed'),error 包含完整 traceback""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.finish.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.finish.get_run", lambda: mock_run) - monkeypatch.setattr(sys, "__excepthook__", MagicMock()) - - tp, val, tb = _make_exc_info(RuntimeError("boom")) - swanlab_excepthook(tp, val, tb) - - call_kwargs = mock_run.finish.call_args.kwargs - assert call_kwargs["state"] == "crashed" - assert "boom" in call_kwargs["error"] - - def test_excepthook_no_run(self, monkeypatch): - """无活跃 Run 时,不调用 run.finish,但不抛出异常""" - monkeypatch.setattr("swanlab.sdk.cmd.finish.has_run", lambda: False) - mock_get_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.finish.get_run", mock_get_run) - monkeypatch.setattr(sys, "__excepthook__", MagicMock()) - - tp, val, tb = _make_exc_info(RuntimeError("no run")) - swanlab_excepthook(tp, val, tb) - - mock_get_run.assert_not_called() - - def test_excepthook_always_calls_original_hook(self, monkeypatch): - """无论是否有活跃 Run,sys.__excepthook__ 必须被调用一次""" - monkeypatch.setattr("swanlab.sdk.cmd.finish.has_run", lambda: False) - mock_original = MagicMock() - monkeypatch.setattr(sys, "__excepthook__", mock_original) - - tp, val, tb = _make_exc_info(RuntimeError("test")) - swanlab_excepthook(tp, val, tb) - - mock_original.assert_called_once_with(tp, val, tb) - - def test_excepthook_internal_error_doesnt_crash(self, monkeypatch): - """excepthook 内部逻辑出错时,不应向上抛出异常,且仍调用 sys.__excepthook__""" - # 模拟 has_run 本身抛出异常(触发 except 块) - monkeypatch.setattr("swanlab.sdk.cmd.finish.has_run", MagicMock(side_effect=Exception("internal boom"))) - # mock console 以隔离 console.error 的实现细节 - monkeypatch.setattr("swanlab.sdk.cmd.finish.console", MagicMock()) - mock_original = MagicMock() - monkeypatch.setattr(sys, "__excepthook__", mock_original) - - tp, val, tb = _make_exc_info(RuntimeError("outer")) - - # 不应抛出 - swanlab_excepthook(tp, val, tb) - - # __excepthook__ 仍须被调用(finally 块保证) - mock_original.assert_called_once_with(tp, val, tb) diff --git a/tests/unit/sdk/cmd/finish/test_finish_e2e.py b/tests/unit/sdk/cmd/finish/test_finish_e2e.py index 12a4aea11..f916d80ff 100644 --- a/tests/unit/sdk/cmd/finish/test_finish_e2e.py +++ b/tests/unit/sdk/cmd/finish/test_finish_e2e.py @@ -9,8 +9,8 @@ import pytest import swanlab -from swanlab.sdk.cmd.finish import finish from swanlab.sdk.cmd.init import init +from swanlab.sdk.cmd.run import finish from swanlab.sdk.internal.run import has_run diff --git a/tests/unit/sdk/cmd/test_log.py b/tests/unit/sdk/cmd/test_log.py deleted file mode 100644 index f64a9d5a0..000000000 --- a/tests/unit/sdk/cmd/test_log.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -@author: cunyue -@file: test_log.py -@time: 2026/3/14 -@description: 测试 swanlab.sdk.cmd.log 中的函数 -""" - -from unittest.mock import MagicMock - -from swanlab.sdk.cmd.log import log, log_text - - -class TestLog: - def test_log_calls_run_log(self, monkeypatch): - """log() 应调用 run.log() 并传递参数""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.helper.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.log.get_run", lambda: mock_run) - - log({"loss": 0.5, "accuracy": 0.95}) - - mock_run.log.assert_called_once_with({"loss": 0.5, "accuracy": 0.95}, None) - - def test_log_with_step(self, monkeypatch): - """log() 应正确传递 step 参数""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.helper.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.log.get_run", lambda: mock_run) - - log({"loss": 0.5}, step=10) - - mock_run.log.assert_called_once_with({"loss": 0.5}, 10) - - -class TestLogText: - def test_log_text_calls_run_log_text(self, monkeypatch): - """log_text() 应调用 run.log_text() 并传递参数""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.helper.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.log.get_run", lambda: mock_run) - - log_text("output", "Training started") - - mock_run.log_text.assert_called_once_with("output", "Training started", None, None) - - def test_log_text_with_caption(self, monkeypatch): - """log_text() 应正确传递 caption 参数""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.helper.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.log.get_run", lambda: mock_run) - - log_text("prediction", "cat", caption="Model output") - - mock_run.log_text.assert_called_once_with("prediction", "cat", "Model output", None) - - def test_log_text_with_step(self, monkeypatch): - """log_text() 应正确传递 step 参数""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.helper.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.log.get_run", lambda: mock_run) - - log_text("status", "Epoch complete", step=5) - - mock_run.log_text.assert_called_once_with("status", "Epoch complete", None, 5) diff --git a/tests/unit/sdk/internal/run/test_finish_hook.py b/tests/unit/sdk/internal/run/test_finish_hook.py new file mode 100644 index 000000000..9be86d83f --- /dev/null +++ b/tests/unit/sdk/internal/run/test_finish_hook.py @@ -0,0 +1,90 @@ +""" +@author: cunyue +@file: test_finish_hook.py +@time: 2026/3/14 +@description: 测试 SwanLabRun._atexit_cleanup / _excepthook 的单元行为(均 mock 依赖,不启动真实 Run) +""" + +import sys +import threading +from unittest.mock import ANY, MagicMock, patch + +from swanlab.sdk.internal.run import SwanLabRun + + +def _make_exc_info(exc: BaseException): + """辅助:构造 (tp, val, tb) 三元组""" + # noinspection PyBroadException + try: + raise exc + except BaseException: + tp, val, tb = sys.exc_info() + assert tp is not None and val is not None + return tp, val, tb + + +def _make_mock_run(state: str = "running") -> MagicMock: + """构造一个最小化的 SwanLabRun 替身""" + mock = MagicMock(spec=SwanLabRun) + mock._state = state + mock._api_lock = threading.RLock() + return mock + + +class TestAtexitCleanup: + def test_no_op_when_not_running(self): + """_state != 'running' 时直接返回,不调用 finish""" + run = _make_mock_run(state="success") + SwanLabRun._atexit_cleanup(run) + run.finish.assert_not_called() + + def test_calls_finish_when_running(self): + """_state == 'running' 时应调用 finish()""" + run = _make_mock_run(state="running") + SwanLabRun._atexit_cleanup(run) + run.finish.assert_called_once() + + +class TestExcepthook: + def test_keyboard_interrupt_calls_aborted(self): + """KeyboardInterrupt → finish(state='aborted', ...)""" + run = _make_mock_run() + with patch("sys.__excepthook__"): + tp, val, tb = _make_exc_info(KeyboardInterrupt()) + SwanLabRun._excepthook(run, tp, val, tb) + run.finish.assert_called_once_with(state="aborted", error=ANY) + + def test_generic_exception_calls_crashed(self): + """普通异常 → finish(state='crashed'),error 包含完整 traceback""" + run = _make_mock_run() + with patch("sys.__excepthook__"): + tp, val, tb = _make_exc_info(RuntimeError("boom")) + SwanLabRun._excepthook(run, tp, val, tb) + call_kwargs = run.finish.call_args.kwargs + assert call_kwargs["state"] == "crashed" + assert "boom" in call_kwargs["error"] + + def test_no_op_when_not_running(self): + """_state != 'running' 时不调用 finish""" + run = _make_mock_run(state="success") + with patch("sys.__excepthook__"): + tp, val, tb = _make_exc_info(RuntimeError("no run")) + SwanLabRun._excepthook(run, tp, val, tb) + run.finish.assert_not_called() + + def test_always_calls_original_excepthook(self): + """无论是否有活跃 Run,sys.__excepthook__ 必须被调用一次""" + run = _make_mock_run(state="success") + with patch("sys.__excepthook__") as mock_original: + tp, val, tb = _make_exc_info(RuntimeError("test")) + SwanLabRun._excepthook(run, tp, val, tb) + mock_original.assert_called_once_with(tp, val, tb) + + def test_internal_error_doesnt_crash(self): + """excepthook 内部出错时不向上抛出,仍调用 sys.__excepthook__""" + run = _make_mock_run() + run.finish.side_effect = Exception("internal boom") + with patch("sys.__excepthook__") as mock_original: + tp, val, tb = _make_exc_info(RuntimeError("outer")) + SwanLabRun._excepthook(run, tp, val, tb) + mock_original.assert_called_once_with(tp, val, tb) diff --git a/tests/unit/sdk/internal/run/test_run_log_media.py b/tests/unit/sdk/internal/run/test_log_media.py similarity index 100% rename from tests/unit/sdk/internal/run/test_run_log_media.py rename to tests/unit/sdk/internal/run/test_log_media.py diff --git a/tests/unit/sdk/internal/run/test_decorators.py b/tests/unit/sdk/internal/run/test_run_decorators.py similarity index 100% rename from tests/unit/sdk/internal/run/test_decorators.py rename to tests/unit/sdk/internal/run/test_run_decorators.py From dbd8bc17a0f18016655f80c74f98d09ddaf8f757 Mon Sep 17 00:00:00 2001 From: Kang Li Date: Mon, 16 Mar 2026 00:02:28 +0800 Subject: [PATCH 4/5] Update __init__.py --- swanlab/sdk/internal/run/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swanlab/sdk/internal/run/__init__.py b/swanlab/sdk/internal/run/__init__.py index 57907a534..2ca12d481 100644 --- a/swanlab/sdk/internal/run/__init__.py +++ b/swanlab/sdk/internal/run/__init__.py @@ -195,7 +195,7 @@ def _cleanup(self): console.debug("Cleanup global instance...") clear_run() deactivate_run_config() - console.debug("Clean & tidy! ( •̀ ω •́ )y") + console.debug("Clean & tidy! ciallo ( ∠・ω< ) ~ ★") # 释放日志,本次运行结束 log.reset() From 719e9a627eaf4f915c571681c67ae0ab02b44e41 Mon Sep 17 00:00:00 2001 From: Kang Li Date: Mon, 16 Mar 2026 00:25:58 +0800 Subject: [PATCH 5/5] Add matplotlib.figure and torch.Tensor support Extend Image constructor type hints and docstring to accept vendor.torch.Tensor and vendor.matplotlib.figure.Figure (matplotlib figures and torch tensors). Also import matplotlib.figure under TYPE_CHECKING and add 'matplotlib.figure' to _SUBMODULE_IMPORTS in vendor/__init__ so the submodule is available for type references. --- swanlab/sdk/internal/run/transforms/image/__init__.py | 11 +++++++++-- swanlab/vendor/__init__.py | 2 ++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/swanlab/sdk/internal/run/transforms/image/__init__.py b/swanlab/sdk/internal/run/transforms/image/__init__.py index a4112495c..4526f6ec0 100644 --- a/swanlab/sdk/internal/run/transforms/image/__init__.py +++ b/swanlab/sdk/internal/run/transforms/image/__init__.py @@ -50,7 +50,14 @@ def _resize(image: "vendor.PIL.Image.Image", size) -> "vendor.PIL.Image.Image": class Image(TransformMedia): def __init__( self, - data_or_path: Union["Image", str, "vendor.PIL.Image.Image", "vendor.np.ndarray"], + data_or_path: Union[ + "Image", + str, + "vendor.PIL.Image.Image", + "vendor.np.ndarray", + "vendor.torch.Tensor", + "vendor.matplotlib.figure.Figure", + ], mode: Optional[str] = None, caption: Optional[str] = None, file_type: Optional[str] = None, @@ -60,7 +67,7 @@ def __init__( Parameters ---------- - data_or_path: str, PIL.Image.Image, numpy.ndarray, torch.Tensor, matplotlib figure, or Image + data_or_path: str, PIL.Image.Image, numpy.ndarray, torch.Tensor, matplotlib.figure.Figure, or Image Path to an image file (PNG/JPG/JPEG/BMP; GIF is not supported), a PIL Image, numpy array (shape: (H, W) or (H, W, 3/4)), torch.Tensor, matplotlib figure, or another Image instance (nesting). diff --git a/swanlab/vendor/__init__.py b/swanlab/vendor/__init__.py index 3875f0899..75284a95a 100644 --- a/swanlab/vendor/__init__.py +++ b/swanlab/vendor/__init__.py @@ -15,6 +15,7 @@ import boto3 import imageio import matplotlib + import matplotlib.figure import moviepy import numpy as np import pandas as pd @@ -82,6 +83,7 @@ # so their attributes are accessible (e.g. PIL.Image must be imported for PIL.Image to work) _SUBMODULE_IMPORTS = { "PIL": ["PIL.Image"], + "matplotlib": ["matplotlib.figure"], }