diff --git a/swanlab/__init__.py b/swanlab/__init__.py index 25555ca80..bb8202a6c 100644 --- a/swanlab/__init__.py +++ b/swanlab/__init__.py @@ -2,16 +2,20 @@ Settings, SwanLabRun, config, + define_scalar, finish, get_run, has_run, init, log, + log_audio, + log_image, log_text, + log_video, login, merge_settings, ) -from swanlab.sdk.internal.run.transforms import Audio, Text +from swanlab.sdk.internal.run.transforms import Audio, Image, Text, Video from swanlab.sdk.utils.version import get_swanlab_version from . import utils @@ -25,15 +29,21 @@ "login", "log", "log_text", + "log_image", + "log_audio", + "log_video", + "define_scalar", # run "SwanLabRun", "has_run", "get_run", + # config + "config", # data "Text", "Audio", - # config - "config", + "Image", + "Video", # utils "utils", ] diff --git a/swanlab/__init__.pyi b/swanlab/__init__.pyi new file mode 100644 index 000000000..dfaf08fec --- /dev/null +++ b/swanlab/__init__.pyi @@ -0,0 +1,409 @@ +""" +swanlab public API type stubs. + +This file is the single source of truth for all type signatures exposed at the +top-level `swanlab` namespace. Add new public symbols here when they are added +to swanlab/__init__.py. +""" + +from typing import Any, List, Mapping, Optional, Union + +from . import utils +from .sdk import config +from .sdk.cmd.init import ConfigLike +from .sdk.internal.run import SwanLabRun +from .sdk.internal.run.transforms import Audio, Image, Text, Video +from .sdk.internal.settings import Settings +from .sdk.typings.run import FinishType, ModeType, ResumeType +from .sdk.typings.run.column import ScalarXAxisType +from .sdk.utils.callbacker import SwanLabCallback + +__version__: str + +__all__ = [ + # cmd + "merge_settings", + "Settings", + "init", + "finish", + "login", + "log", + "log_text", + "log_image", + "log_audio", + "log_video", + "define_scalar", + # run + "SwanLabRun", + "has_run", + "get_run", + # data + "Text", + "Audio", + "Image", + "Video", + # config + "config", + # utils + "utils", +] + +# ── lifecycle ────────────────────────────────────────────────────────────────── + +def init( + *, + reinit: Optional[bool] = None, + logdir: Optional[str] = None, + mode: Optional[ModeType] = None, + workspace: Optional[str] = None, + project: Optional[str] = None, + public: Optional[bool] = None, + name: Optional[str] = None, + color: Optional[str] = None, + description: Optional[str] = None, + job_type: Optional[str] = None, + group: Optional[str] = None, + tags: Optional[List[str]] = None, + id: Optional[str] = None, + resume: Optional[Union[ResumeType, bool]] = None, + config: Optional[ConfigLike] = None, + settings: Optional[Settings] = None, + callbacks: Optional[List[SwanLabCallback]] = None, + **kwargs: Any, +) -> SwanLabRun: + """Initialize a new SwanLab run to track experiments. + + This function starts a new run for logging metrics, artifacts, and metadata. + After calling this, use `swanlab.log()` to log data and `swanlab.finish()` to + close the run. SwanLab automatically finishes runs at program exit. + + :param reinit: If True, finish the current run before starting a new one. Defaults to False. + :param logdir: Directory to store logs. Defaults to "./swanlog". + :param mode: Run mode. Options: "cloud" (sync to cloud), "local" (local only), + "offline" (save locally for later sync), "disabled" (no logging). Defaults to "cloud". + :param workspace: Workspace or organization name. Defaults to current user. + :param project: Project name. Defaults to current directory name. + :param public: Make project publicly visible (cloud mode only). Defaults to False. + :param name: Experiment name. Auto-generated if not provided. + :param color: Experiment color for visualization. Auto-generated if not provided. + :param description: Experiment description. + :param job_type: Job type label (e.g., "train", "eval"). + :param group: Group name for organizing related experiments. + :param tags: List of tags for categorizing experiments. + :param id: Run ID for resuming a previous run (cloud mode only). + :param resume: Resume behavior. Options: "must" (must resume), "allow" (resume if exists), + "never" (always create new). Defaults to "never". + :param config: Experiment configuration dict or path to config file (JSON/YAML). + :param settings: Custom Settings object for advanced configuration. + :param callbacks: List of callback functions triggered on run events. + :return: The initialized SwanLabRun object. + :raises RuntimeError: If a run is already active and reinit=False. + + Examples: + + Basic local run: + + >>> import swanlab + >>> swanlab.init(mode="local", project="my_project") + >>> swanlab.log({"loss": 0.5}) + >>> swanlab.finish() + + Cloud run with configuration: + + >>> import swanlab + >>> swanlab.login(api_key="your_key") + >>> swanlab.init( + ... mode="cloud", + ... project="image_classification", + ... name="resnet50_experiment", + ... config={"lr": 0.001, "batch_size": 32} + ... ) + >>> swanlab.log({"accuracy": 0.95}) + >>> swanlab.finish() + """ + ... + +def finish(state: FinishType = "success", error: Optional[str] = None) -> None: + """Finish the current run and close the experiment. + + This function safely closes the current run and waits for all logs to be flushed. + SwanLab automatically calls this function at program exit, but you can call it + manually to mark the experiment as completed with a specific state. + + :param state: Final state of the run. Must be one of: "success", "crashed", "aborted". + Defaults to "success". + :param error: Error message if state is "crashed". Required when state="crashed". + :raises RuntimeError: If called without an active run. + + Examples: + + Finish a successful run: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> swanlab.log({"loss": 0.5}) + >>> swanlab.finish() + + Mark run as crashed with error message: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> try: + ... raise ValueError("Training failed") + ... except Exception as e: + ... swanlab.finish(state="crashed", error=str(e)) + """ + ... + +def login( + api_key: Optional[str] = None, + relogin: bool = False, + host: Optional[str] = None, + save: bool = False, + timeout: int = 10, +) -> bool: + """Authenticate with SwanLab Cloud. + + This function authenticates your environment with SwanLab. If already logged in + and `relogin` is False, this function does nothing. Call this before `swanlab.init()` + to use cloud features. + + :param api_key: Your SwanLab API key. If not provided, will attempt to read from + environment or prompt for input. + :param relogin: If True, forces re-authentication and overwrites existing credentials. + Defaults to False. + :param host: Custom API host URL. If not provided, uses the default SwanLab cloud host. + :param save: Whether to save the API key locally for future sessions. Defaults to False. + :param timeout: Network request timeout in seconds. Defaults to 10. + :return: True if login was successful, False otherwise. + :raises RuntimeError: If called while a run is active. + :raises AuthenticationError: If login fails due to invalid credentials or network issues. + + Examples: + + Login with an API key: + + >>> import swanlab + >>> swanlab.login(api_key="your_api_key_here") + >>> swanlab.init(mode="cloud") + + Force re-login and save credentials: + + >>> import swanlab + >>> swanlab.login(api_key="new_api_key", relogin=True, save=True) + """ + ... + +def merge_settings(settings: Union[Settings, dict]) -> None: + """Merge custom settings into the global SwanLab configuration. + + This function allows you to customize SwanLab's behavior before initializing a run. + It must be called before `swanlab.init()`. + + :param settings: Custom settings to merge. Can be either a Settings object or a dict. + :raises RuntimeError: If called while a run is active. + + Examples: + + >>> import swanlab + >>> swanlab.merge_settings({"mode": "local", "logdir": "./my_logs"}) + >>> swanlab.init() + """ + ... + +# ── run access ───────────────────────────────────────────────────────────────── + +def has_run() -> bool: + """Check if there is an active SwanLab run. + + :return: True if a run is currently active, False otherwise. + + Examples: + + >>> import swanlab + >>> if swanlab.has_run(): + ... swanlab.log({"metric": 1.0}) + ... else: + ... print("No active run") + """ + ... + +def get_run() -> SwanLabRun: + """Get the current active SwanLab run. + + :return: The active SwanLabRun instance. + :raises RuntimeError: If no run is currently active. + + Examples: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> run = swanlab.get_run() + >>> print(run.id) + >>> swanlab.finish() + """ + ... + +# ── logging ──────────────────────────────────────────────────────────────────── + +def log(data: Mapping[str, Any], step: Optional[int] = None) -> None: + """Log metrics and data to the current run. + + :param data: Dictionary of metric names and values to log. + :param step: Optional step number. If not provided, auto-increments. + :raises RuntimeError: If called without an active run. + + Examples: + + Log multiple metrics: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> swanlab.log({"loss": 0.5, "accuracy": 0.95}) + >>> swanlab.finish() + + Log with explicit step: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> swanlab.log({"loss": 0.5}, step=10) + >>> swanlab.finish() + """ + ... + +def log_text( + key: str, + data: Union[str, Text, List[str], List[Text]], + caption: Optional[Union[str, List[str]]] = None, + step: Optional[int] = None, +) -> None: + """A syntactic sugar for logging text data. + + :param key: The key for the text data. + :param data: The text data itself or a Text object. + :param caption: Optional caption for the text data. + :param step: Optional step number. If not provided, auto-increments. + :raises RuntimeError: If called without an active run. + + Examples: + + Log simple text: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> swanlab.log_text("output", "Training started") + >>> swanlab.finish() + + Log text with caption: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> swanlab.log_text("prediction", "cat", caption="Model output") + >>> swanlab.finish() + """ + ... + +def log_image( + key: str, + data: Union[Image, Any, List[Any]], + caption: Optional[Union[str, List[str]]] = None, + step: Optional[int] = None, +) -> None: + """A syntactic sugar for logging image data. + + :param key: The key for the image data. + :param data: The image data itself or an Image object. + :param caption: Optional caption for the image data. + :param step: Optional step number. If not provided, auto-increments. + :raises RuntimeError: If called without an active run. + + Examples: + + >>> import swanlab, numpy as np + >>> swanlab.init(mode="local") + >>> img = np.zeros((64, 64, 3), dtype=np.uint8) + >>> swanlab.log_image("sample", img, caption="blank image") + >>> swanlab.finish() + """ + ... + +def log_audio( + key: str, + data: Union[Audio, Any, List[Any]], + sample_rate: int = 44100, + caption: Optional[Union[str, List[str]]] = None, + step: Optional[int] = None, +) -> None: + """A syntactic sugar for logging audio data. + + :param key: The key for the audio data. + :param data: The audio data itself or an Audio object. + :param sample_rate: Sample rate of the audio (used when data is raw numpy array). + :param caption: Optional caption for the audio data. + :param step: Optional step number. If not provided, auto-increments. + :raises RuntimeError: If called without an active run. + + Examples: + + >>> import swanlab, numpy as np + >>> swanlab.init(mode="local") + >>> audio = np.zeros((1, 44100), dtype=np.float32) + >>> swanlab.log_audio("sound", audio, sample_rate=44100) + >>> swanlab.finish() + """ + ... + +def log_video( + key: str, + data: Union[Video, Any, List[Any]], + caption: Optional[Union[str, List[str]]] = None, + step: Optional[int] = None, +) -> None: + """A syntactic sugar for logging video data. + + Currently supported formats: GIF. + + :param key: The key for the video data. + :param data: The video data itself or a Video object. + :param caption: Optional caption for the video data. + :param step: Optional step number. If not provided, auto-increments. + :raises RuntimeError: If called without an active run. + + Examples: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> with open("animation.gif", "rb") as f: + ... swanlab.log_video("rollout", f.read()) + >>> swanlab.finish() + """ + ... + +def define_scalar( + key: str, + name: Optional[str] = None, + color: Optional[str] = None, + x_axis: Optional[ScalarXAxisType] = None, + chart_name: Optional[str] = None, +) -> None: + """Explicitly define a scalar column. + + Call this before logging to customize how a scalar metric is displayed, + such as setting a display name, color, or x-axis type. + + :param key: The key for the scalar column. Supports glob patterns (e.g. "train/*") to match multiple columns at once. + :param name: Optional display name for the scalar column. + :param color: Optional hex color for the scalar line in charts. + :param x_axis: Optional x-axis type. One of "_step", "_relative_time", or a custom key. + :param chart_name: Optional name for the chart group this column belongs to. + :raises RuntimeError: If called without an active run. + + Examples: + + >>> import swanlab + >>> swanlab.init(mode="local") + >>> swanlab.define_scalar("loss", color="#FF5733", x_axis="_step") + >>> swanlab.log({"loss": 0.5}) + >>> swanlab.finish() + """ + ... diff --git a/swanlab/exceptions.py b/swanlab/exceptions.py index 525b2fd98..1ac3216ef 100644 --- a/swanlab/exceptions.py +++ b/swanlab/exceptions.py @@ -9,6 +9,8 @@ from requests.exceptions import HTTPError +__all__ = ["ApiError", "AuthenticationError", "DataStoreError"] + class ApiError(HTTPError): """ diff --git a/swanlab/main.py b/swanlab/main.py deleted file mode 100644 index 04e774a6e..000000000 --- a/swanlab/main.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -@author: cunyue -@file: main.py -@time: 2026/3/5 13:11 -@description: hello world -""" - -from typing import TypedDict - - -def main(): - t: Test = {"a": 2, "b": 2} - print("hello world", t["a"], t["b"]) - - -class Test(TypedDict): - a: int - b: int - - -if __name__ == "__main__": - main() diff --git a/swanlab/sdk/__init__.py b/swanlab/sdk/__init__.py index b966226aa..570350141 100644 --- a/swanlab/sdk/__init__.py +++ b/swanlab/sdk/__init__.py @@ -5,11 +5,10 @@ @description: SwanLab SDK,负责SwanLab库的核心指标上传功能 """ -from swanlab.sdk.cmd.finish import finish from swanlab.sdk.cmd.init import init -from swanlab.sdk.cmd.log import log, log_text from swanlab.sdk.cmd.login import login from swanlab.sdk.cmd.merge_settings import Settings, merge_settings +from swanlab.sdk.cmd.run import define_scalar, finish, log, log_audio, log_image, log_text, log_video from swanlab.sdk.internal.run import SwanLabRun, clear_run, get_run, has_run, set_run from swanlab.sdk.internal.run.config import config @@ -21,6 +20,10 @@ "login", "log", "log_text", + "log_image", + "log_audio", + "log_video", + "define_scalar", "SwanLabRun", "has_run", "get_run", diff --git a/swanlab/sdk/cmd/finish.py b/swanlab/sdk/cmd/finish.py deleted file mode 100644 index 147cbd276..000000000 --- a/swanlab/sdk/cmd/finish.py +++ /dev/null @@ -1,107 +0,0 @@ -""" -@author: cunyue -@file: finish.py -@time: 2026/3/6 21:49 -@description: SwanLab SDK 结束当前运行 -""" - -import atexit -import sys -import traceback -from types import TracebackType -from typing import Optional, Type - -from swanlab.sdk.cmd.helper import with_cmd_lock, with_run -from swanlab.sdk.internal.pkg import console -from swanlab.sdk.internal.run import get_run, has_run -from swanlab.sdk.typings.run import FinishType - - -@with_cmd_lock -@with_run("finish") -def finish(state: FinishType = "success", error: Optional[str] = None): - """Finish the current run and close the experiment. - - This function safely closes the current run and waits for all logs to be flushed. - SwanLab automatically calls this function at program exit, but you can call it - manually to mark the experiment as completed with a specific state. - - :param state: Final state of the run. Must be one of: "success", "crashed", "aborted". - Defaults to "success". - - :param error: Error message if state is "crashed". Required when state="crashed". - - :raises RuntimeError: If called without an active run. - - Examples: - - Finish a successful run: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> swanlab.log({"loss": 0.5}) - >>> swanlab.finish() - - Mark run as crashed with error message: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> try: - ... # training code - ... raise ValueError("Training failed") - ... except Exception as e: - ... swanlab.finish(state="crashed", error=str(e)) - - Mark run as aborted: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> swanlab.log({"loss": 0.5}) - >>> swanlab.finish(state="aborted") - """ - run = get_run() - run.finish(state, error) - - -def atexit_finish(): - """ - 全局退出时自动结束当前运行,此时代码正常执行完毕 - """ - if not has_run(): - return - console.debug("SwanLab Run is finishing at exit...") - run = get_run() - run.finish() - - -atexit.register(atexit_finish) - - -def swanlab_excepthook(tp: Type[BaseException], val: BaseException, tb: Optional[TracebackType]): - """全局异常捕获,用于将实验标记为 crashed""" - try: - if not has_run(): - return - state: FinishType = "crashed" - if tp is KeyboardInterrupt: - console.info("KeyboardInterrupt by user") - state = "aborted" - else: - console.info("Error happened while training") - # 生成错误堆栈 - full_error_msg = "".join(traceback.format_exception(tp, val, tb)) - - # 打印错误堆栈 - run = get_run() - run.finish(state=state, error=full_error_msg) - - except Exception as e: - console.error(f"SwanLab failed to handle excepthook: {e}", file=sys.stderr) - finally: - # _original_excepthook = sys.excepthook - # 不要用动态保存的 _original_excepthook,直接调用 Python 底层 C 实现的 sys.__excepthook__ - # 确保异常信息被正确打印 - sys.__excepthook__(tp, val, tb) - - -sys.excepthook = swanlab_excepthook diff --git a/swanlab/sdk/cmd/log.py b/swanlab/sdk/cmd/log.py deleted file mode 100644 index 4abc99715..000000000 --- a/swanlab/sdk/cmd/log.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -@author: cunyue -@file: log.py -@time: 2026/3/14 -@description: SwanLab SDK logging methods -""" - -from typing import Any, Mapping, Optional, Union - -from swanlab.sdk.cmd.helper import with_cmd_lock, with_run -from swanlab.sdk.internal.run import get_run -from swanlab.sdk.internal.run.transforms import Text - - -@with_cmd_lock -@with_run("log") -def log(data: Mapping[str, Any], step: Optional[int] = None): - """Log metrics and data to the current run. - - :param data: Dictionary of metric names and values to log. - - :param step: Optional step number. If not provided, auto-increments. - - :raises RuntimeError: If called without an active run. - - Examples: - - Log single metric: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> swanlab.log({"loss": 0.5}) - >>> swanlab.finish() - - Log multiple metrics: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> swanlab.log({"loss": 0.5, "accuracy": 0.95}) - >>> swanlab.finish() - - Log with explicit step: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> swanlab.log({"loss": 0.5}, step=10) - >>> swanlab.finish() - """ - run = get_run() - run.log(data, step) - - -@with_cmd_lock -@with_run("log_text") -def log_text(key: str, data: Union[str, Text], caption: Optional[str] = None, step: Optional[int] = None): - """Log text data to the current run. - - :param key: The key for the text data. - - :param data: The text data itself or a Text object. - - :param caption: Optional caption for the text data. - - :param step: Optional step number. If not provided, auto-increments. - - :raises RuntimeError: If called without an active run. - - Examples: - - Log simple text: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> swanlab.log_text("output", "Training started") - >>> swanlab.finish() - - Log text with caption: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> swanlab.log_text("prediction", "cat", caption="Model output") - >>> swanlab.finish() - - Log text with step: - - >>> import swanlab - >>> swanlab.init(mode="local") - >>> swanlab.log_text("status", "Epoch 5 complete", step=5) - >>> swanlab.finish() - """ - run = get_run() - run.log_text(key, data, caption, step) diff --git a/swanlab/sdk/cmd/run.py b/swanlab/sdk/cmd/run.py new file mode 100644 index 000000000..f1b30955e --- /dev/null +++ b/swanlab/sdk/cmd/run.py @@ -0,0 +1,43 @@ +""" +@author: cunyue +@file: run.py +@time: 2026/3/14 +@description: SwanLab SDK run methods + +_make_run_cmd 将 SwanLabRun 上的方法包装为顶层 cmd 函数: + - __wrapped__ 指向 SwanLabRun 原方法,help() / inspect 可追溯 + - __doc__ 复用原方法文档,无需重复维护 + +新增公开方法时,在 SwanLabRun 上实现后,在此文件末尾追加一行: + xxx = _make_run_cmd("xxx") +并同步更新 swanlab/__init__.pyi 中的函数声明。 +""" + +from typing import Any, Callable + +from swanlab.sdk.cmd.helper import with_cmd_lock, with_run +from swanlab.sdk.internal.run import SwanLabRun, get_run + + +def _make_run_cmd(method_name: str) -> Callable: + run_method = getattr(SwanLabRun, method_name) + + @with_cmd_lock + @with_run(method_name) + def wrapper(*args: Any, **kwargs: Any) -> Any: + return getattr(get_run(), method_name)(*args, **kwargs) + + wrapper.__name__ = method_name + wrapper.__wrapped__ = run_method # type: ignore[attr-defined] + wrapper.__doc__ = run_method.__doc__ + return wrapper + + +# ── 每新增一个公开方法,在此追加一行 ────────────────────────────────────────── +log = _make_run_cmd("log") +log_text = _make_run_cmd("log_text") +log_image = _make_run_cmd("log_image") +log_audio = _make_run_cmd("log_audio") +log_video = _make_run_cmd("log_video") +define_scalar = _make_run_cmd("define_scalar") +finish = _make_run_cmd("finish") diff --git a/swanlab/sdk/internal/run/__init__.py b/swanlab/sdk/internal/run/__init__.py index 1d993fb6c..2ca12d481 100644 --- a/swanlab/sdk/internal/run/__init__.py +++ b/swanlab/sdk/internal/run/__init__.py @@ -8,10 +8,14 @@ 3. 触发异步微批处理落盘与回调 """ +import atexit +import sys import threading +import traceback from functools import cached_property, wraps from pathlib import Path -from typing import Any, List, Literal, Mapping, Optional, Union, cast, get_args +from types import TracebackType +from typing import Any, List, Literal, Mapping, Optional, Type, Union, cast, get_args from google.protobuf.timestamp_pb2 import Timestamp @@ -28,7 +32,7 @@ create_unbound_run_config, deactivate_run_config, ) -from swanlab.sdk.internal.run.transforms import Text, normalize_media_input +from swanlab.sdk.internal.run.transforms import Audio, Image, Text, Video, normalize_media_input from swanlab.sdk.typings.run import FinishType from swanlab.sdk.typings.run.column import ScalarXAxisType @@ -137,15 +141,64 @@ def __init__(self, ctx: RunContext): # 设置全局运行实例 set_run(self) - - # 绑定日志文件 + # 注册退出钩子 + self._sys_origin_excepthook = sys.excepthook + atexit.register(self._atexit_cleanup) + sys.excepthook = self._excepthook + # 绑定日志文件,运行正式开始 if self._ctx.config.settings.mode != "disabled": log.bindfile(self._ctx.debug_dir) # ---------------------------------- - # 属性 (Properties) + # 私有钩子 # ---------------------------------- + def _atexit_cleanup(self) -> None: + """程序正常退出时自动结束当前运行""" + if self._state != "running": + return + console.debug("SwanLab Run is finishing at exit...") + self.finish() + + def _excepthook( + self, + tp: Type[BaseException], + val: BaseException, + tb: Optional[TracebackType], + ) -> None: + """全局异常捕获,将实验标记为 crashed 或 aborted""" + try: + if self._state != "running": + return + state: FinishType = "crashed" + if tp is KeyboardInterrupt: + console.info("KeyboardInterrupt by user") + state = "aborted" + else: + console.info("Error happened while training") + full_error_msg = "".join(traceback.format_exception(tp, val, tb)) + self.finish(state=state, error=full_error_msg) + except Exception as e: + console.error(f"SwanLab failed to handle excepthook: {e}") + finally: + sys.__excepthook__(tp, val, tb) + + def _cleanup(self): + """ + 清除副作用 + """ + # 取消钩子 + console.debug("Cleanup system hook...") + atexit.unregister(self._atexit_cleanup) + sys.excepthook = self._sys_origin_excepthook + # 清理全局运行实例 + console.debug("Cleanup global instance...") + clear_run() + deactivate_run_config() + console.debug("Clean & tidy! ciallo ( ∠・ω< ) ~ ★") + # 释放日志,本次运行结束 + log.reset() + @cached_property def id(self) -> str: """ @@ -259,27 +312,95 @@ def log_text( @with_lock @with_run - def define_scalar( + def log_image( self, key: str, - name: Optional[str] = None, - color: Optional[str] = None, - x_axis: Optional[ScalarXAxisType] = None, - chart_name: Optional[str] = None, + data: Union[Image, Any, List[Any]], + caption: Optional[Union[str, List[str]]] = None, + step: Optional[int] = None, + ): + """ + A syntactic sugar for logging image data. + + :param key: The key for the image data. + + :param data: The image data itself or an Image object. + + :param caption: Optional caption for the image data. + + :param step: Optional step for the image data. + """ + normalized_data = normalize_media_input(Image, data, caption=caption) + self.log({key: normalized_data}, step=step) + + @with_lock + @with_run + def log_audio( + self, + key: str, + data: Union[Audio, Any, List[Any]], + sample_rate: int = 44100, + caption: Optional[Union[str, List[str]]] = None, + step: Optional[int] = None, + ): + """ + A syntactic sugar for logging audio data. + + :param key: The key for the audio data. + + :param data: The audio data itself or an Audio object. + + :param sample_rate: Sample rate of the audio (used when data is raw numpy array). + + :param caption: Optional caption for the audio data. + + :param step: Optional step for the audio data. + """ + normalized_data = normalize_media_input(Audio, data, caption=caption, sample_rate=sample_rate) + self.log({key: normalized_data}, step=step) + + @with_lock + @with_run + def log_video( + self, + key: str, + data: Union[Video, Any, List[Any]], + caption: Optional[Union[str, List[str]]] = None, + step: Optional[int] = None, ): """ - Explicitly define a scalar column. + A syntactic sugar for logging video data. - :param key: The key for the scalar column. + :param key: The key for the video data. - :param name: Optional name for the scalar column. + :param data: The video data itself or a Video object. - :param color: Optional color for the scalar column. + :param caption: Optional caption for the video data. - :param x_axis: Optional x-axis for the scalar column. + :param step: Optional step for the video data. + """ + normalized_data = normalize_media_input(Video, data, caption=caption) + self.log({key: normalized_data}, step=step) - :param chart_name: Optional name for the chart. + @with_lock + @with_run + def define_scalar( + self, + key: str, + name: Optional[str] = None, + color: Optional[str] = None, + x_axis: Optional[ScalarXAxisType] = None, + chart_name: Optional[str] = None, + ): + """ + 手动定义一个标量列 + :param key: 标量列的键,支持通配符(如 "train/*")以匹配多个列 + :param name: 标量列的可选显示名称 + :param color: 标量列的可选颜色 + :param x_axis: 标量列的可选 x 轴类型 + :param chart_name: 标量列所属的可选图表名称 """ + # TODO: 实现 glob 匹配逻辑 if not (this_key := fmt.safe_validate_key(key)): return console.error( f"Invalid key for define scalar: {key}, please use valid characters (alphanumeric, '.', '-', '/') and avoid special characters." @@ -337,12 +458,8 @@ def finish(self, state: FinishType = "success", error: Optional[str] = None): self._emitter.emit(RunFinishEvent(state=this_state, error=error, timestamp=ts)) # 阻塞主线程,等待后台队列消费完毕 self._consumer.join() - # 清理全局运行实例 - clear_run() - console.debug(f"Run finished with state: {state}") - # 释放一些资源 - log.reset() - deactivate_run_config() + console.debug(f"SwanLab Run has finished with state: {self._state}, cleanup...") + self._cleanup() _current_run: Optional[SwanLabRun] = None diff --git a/swanlab/sdk/internal/run/transforms/__init__.py b/swanlab/sdk/internal/run/transforms/__init__.py index 9c195f6db..7f48ed784 100644 --- a/swanlab/sdk/internal/run/transforms/__init__.py +++ b/swanlab/sdk/internal/run/transforms/__init__.py @@ -10,10 +10,12 @@ from swanlab.sdk.internal.context import TransformMedia from .audio import Audio +from .image import Image from .scalar import Scalar from .text import Text +from .video import Video -__all__ = ["Text", "Scalar", "Audio", "normalize_media_input"] +__all__ = ["Text", "Scalar", "Audio", "Image", "Video", "normalize_media_input"] def normalize_media_input( @@ -56,10 +58,7 @@ def normalize_media_input( # 构造媒体对象列表 result = [] for i, item in enumerate(data_list): - if isinstance(item, media_cls): - result.append(item) - else: - item_kwargs = {k: v[i] for k, v in normalized_kwargs.items()} - result.append(media_cls(*[item], **item_kwargs)) + item_kwargs = {k: v[i] for k, v in normalized_kwargs.items()} + result.append(media_cls(*[item], **item_kwargs)) return result diff --git a/swanlab/sdk/internal/run/transforms/image/__init__.py b/swanlab/sdk/internal/run/transforms/image/__init__.py index 1f61c7499..4526f6ec0 100644 --- a/swanlab/sdk/internal/run/transforms/image/__init__.py +++ b/swanlab/sdk/internal/run/transforms/image/__init__.py @@ -1,6 +1,174 @@ """ @author: cunyue @file: __init__.py -@time: 2026/3/11 19:17 +@time: 2026/3/15 @description: 图像处理模块 """ + +import hashlib +from io import BytesIO +from pathlib import Path +from typing import List, Optional, Union + +from google.protobuf.timestamp_pb2 import Timestamp + +from swanlab import vendor +from swanlab.proto.swanlab.metric.column.v1.column_pb2 import ColumnType +from swanlab.proto.swanlab.metric.data.v1.data_pb2 import DataRecord +from swanlab.proto.swanlab.metric.data.v1.media.image_pb2 import ImageItem, ImageValue +from swanlab.sdk.internal.context import TransformMedia +from swanlab.sdk.internal.pkg.fs import safe_write + +ACCEPT_FORMAT = ["png", "jpg", "jpeg", "bmp"] + + +def _is_torch_tensor(obj) -> bool: + """通过类型名检测 PyTorch Tensor,避免强制导入 torch""" + typename = obj.__class__.__module__ + "." + obj.__class__.__name__ + return typename.startswith("torch.") and ("Tensor" in typename or "Variable" in typename) + + +def _resize(image: "vendor.PIL.Image.Image", size) -> "vendor.PIL.Image.Image": + """按 size 参数缩放图像""" + if size is None: + return image + if isinstance(size, int): + if max(image.size) > size: + image.thumbnail((size, size)) + return image + if isinstance(size, (list, tuple)): + w, h = (tuple(size) + (None,))[:2] + if w is not None and h is not None: + return image.resize((int(w), int(h))) + if w is not None: + return image.resize((int(w), int(image.size[1] * w / image.size[0]))) + if h is not None: + return image.resize((int(image.size[0] * h / image.size[1]), int(h))) + raise ValueError("size must be an int, or a list/tuple with 1-2 elements") + + +class Image(TransformMedia): + def __init__( + self, + data_or_path: Union[ + "Image", + str, + "vendor.PIL.Image.Image", + "vendor.np.ndarray", + "vendor.torch.Tensor", + "vendor.matplotlib.figure.Figure", + ], + mode: Optional[str] = None, + caption: Optional[str] = None, + file_type: Optional[str] = None, + size: Optional[Union[int, list, tuple]] = None, + ): + """Image class constructor + + Parameters + ---------- + data_or_path: str, PIL.Image.Image, numpy.ndarray, torch.Tensor, matplotlib.figure.Figure, or Image + Path to an image file (PNG/JPG/JPEG/BMP; GIF is not supported), a PIL Image, + numpy array (shape: (H, W) or (H, W, 3/4)), torch.Tensor, matplotlib figure, + or another Image instance (nesting). + mode: str, optional + PIL mode applied when converting to PIL.Image (e.g. 'RGB', 'L'). + caption: str, optional + Caption for the image. + file_type: str, optional + Output file format. One of ['png', 'jpg', 'jpeg', 'bmp']. Defaults to 'png'. + size: int, list, or tuple, optional + Resize policy: + - int: maximum side length (aspect-ratio preserved via thumbnail). + - (w, h): exact target size. + - (w, None) / (None, h): fix one dimension, scale the other proportionally. + - None: no resize. + """ + super().__init__() + + # 套娃加载 + attrs = self._unwrap(data_or_path) + if attrs: + self.buffer: BytesIO = attrs["buffer"] + self.file_type: str = attrs["file_type"] + self.caption: Optional[str] = caption if caption is not None else attrs.get("caption") + return + + # 校验 file_type + ft = (file_type or "png").lower() + if ft not in ACCEPT_FORMAT: + raise ValueError(f"Unsupported file_type '{ft}'. Accepted: {ACCEPT_FORMAT}") + self.file_type = ft + + # ---------- 各类型输入 → PIL Image ---------- + # 考虑到懒加载限制我们一般使用鸭子类型判断,而不是 isinstance + PILImage = vendor.PIL.Image + # 1. 文件路径 + if isinstance(data_or_path, str): + if data_or_path.lower().endswith(".gif"): + raise TypeError("GIF images are not supported. Please convert to PNG or JPG first.") + try: + pil_img = PILImage.open(data_or_path) + except Exception as e: + raise ValueError(f"Failed to open image file: {data_or_path!r}") from e + if getattr(pil_img, "format", None) == "GIF": + raise TypeError("GIF images are not supported. Please convert to PNG or JPG first.") + image_data = pil_img.convert(mode) + # 2. PIL Image + elif isinstance(data_or_path, PILImage.Image): + if getattr(data_or_path, "format", None) == "GIF": + raise TypeError("GIF images are not supported. Please convert to PNG or JPG first.") + image_data = data_or_path.convert(mode) + # 3. PyTorch Tensor + elif _is_torch_tensor(data_or_path): + t = data_or_path + if hasattr(t, "requires_grad") and t.requires_grad: # type: ignore[union-attr] + t = t.detach() # type: ignore[union-attr] + t = vendor.torchvision.utils.make_grid(t, normalize=True) # type: ignore[arg-type] + image_data = PILImage.fromarray(t.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy(), mode=mode) + # 4. Matplotlib Figure + elif hasattr(data_or_path, "savefig"): + try: + buf = BytesIO() + data_or_path.savefig(buf, format=self.file_type) # type: ignore[union-attr] + buf.seek(0) + image_data = PILImage.open(buf).convert(mode) + buf.close() + except Exception as e: + raise TypeError("Failed to convert matplotlib figure to image") from e + # 5. Numpy Array + elif isinstance(data_or_path, vendor.np.ndarray): + arr = data_or_path + if arr.ndim == 2 or (arr.ndim == 3 and arr.shape[2] in (3, 4)): + image_data = PILImage.fromarray(vendor.np.clip(arr, 0, 255).astype(vendor.np.uint8), mode=mode) + else: + raise TypeError(f"Invalid numpy array shape for Image: expected (H, W) or (H, W, 3/4), got {arr.shape}") + + else: + raise TypeError( + f"Unsupported image type: {type(data_or_path).__name__}. " + "Please provide a file path, PIL.Image, numpy.ndarray, torch.Tensor, or matplotlib figure." + ) + + image_data = _resize(image_data, size) + self.buffer = BytesIO() + save_fmt = "jpeg" if self.file_type == "jpg" else self.file_type + image_data.save(self.buffer, format=save_fmt) + self.caption = caption + + @classmethod + def column_type(cls) -> ColumnType: + return ColumnType.COLUMN_TYPE_IMAGE + + @classmethod + def build_data_record(cls, *, key: str, step: int, timestamp: Timestamp, data: List[ImageItem]) -> DataRecord: + return DataRecord( + key=key, step=step, timestamp=timestamp, type=cls.column_type(), images=ImageValue(items=data) + ) + + def transform(self, *, step: int, path: Path) -> ImageItem: + content = self.buffer.getvalue() + sha256 = hashlib.sha256(content).hexdigest() + filename = f"{step:03d}-{sha256[:8]}.{self.file_type}" + safe_write(path / filename, content, mode="wb") + return ImageItem(filename=filename, sha256=sha256, size=len(content), caption=self.caption or "") diff --git a/swanlab/sdk/internal/run/transforms/video/__init__.py b/swanlab/sdk/internal/run/transforms/video/__init__.py new file mode 100644 index 000000000..d282e8b1e --- /dev/null +++ b/swanlab/sdk/internal/run/transforms/video/__init__.py @@ -0,0 +1,120 @@ +""" +@author: cunyue +@file: __init__.py +@time: 2026/3/15 +@description: 视频处理模块,暂时只支持 GIF +""" + +import hashlib +from io import BytesIO +from pathlib import Path +from typing import List, Optional, Union + +from google.protobuf.timestamp_pb2 import Timestamp + +from swanlab.proto.swanlab.metric.column.v1.column_pb2 import ColumnType +from swanlab.proto.swanlab.metric.data.v1.data_pb2 import DataRecord +from swanlab.proto.swanlab.metric.data.v1.media.video_pb2 import VideoItem, VideoValue +from swanlab.sdk.internal.context import TransformMedia +from swanlab.sdk.internal.pkg.fs import safe_write + +# 各格式的魔数校验表,新增格式时在此追加 +# format → (magic_bytes, ...) +_FORMAT_MAGIC: dict[str, tuple[bytes, ...]] = { + "gif": (b"GIF87a", b"GIF89a"), +} + +# 路径后缀 → 格式名 +_EXT_TO_FORMAT: dict[str, str] = { + ".gif": "gif", +} + + +def _detect_format_by_magic(data: bytes) -> Optional[str]: + """根据魔数推断格式,无法识别则返回 None""" + for fmt, magics in _FORMAT_MAGIC.items(): + if any(data.startswith(m) for m in magics): + return fmt + return None + + +class Video(TransformMedia): + def __init__( + self, + data_or_path: Union["Video", str, bytes, BytesIO], + caption: Optional[str] = None, + ): + """Video class constructor + + 目前支持的格式:GIF。 + + Parameters + ---------- + data_or_path: str, bytes, BytesIO, or Video + Path to a supported video file, raw video bytes, a BytesIO containing + video data, or another Video instance (nesting). + caption: str, optional + Caption for the video. + """ + super().__init__() + + # 套娃加载 + attrs = self._unwrap(data_or_path) + if attrs: + self.buffer: BytesIO = attrs["buffer"] + self.format: str = attrs["format"] + self.caption: Optional[str] = caption if caption is not None else attrs.get("caption") + return + + # 1. 文件路径 + if isinstance(data_or_path, str): + ext = Path(data_or_path).suffix.lower() + if ext not in _EXT_TO_FORMAT: + supported = ", ".join(_EXT_TO_FORMAT) + raise TypeError(f"Unsupported file extension '{ext}'. Supported: {supported}") + try: + with open(data_or_path, "rb") as f: + raw = f.read() + except OSError as e: + raise ValueError(f"Failed to open file: {data_or_path!r}") from e + fmt = _detect_format_by_magic(raw) + if fmt is None: + raise TypeError(f"File '{data_or_path}' does not match any known video format magic number.") + self.format = fmt + + # 2. bytes 或 BytesIO + elif isinstance(data_or_path, (bytes, BytesIO)): + raw = data_or_path if isinstance(data_or_path, bytes) else data_or_path.read() + fmt = _detect_format_by_magic(raw) + if fmt is None: + supported = ", ".join(_FORMAT_MAGIC) + raise TypeError(f"Cannot detect video format from bytes. Supported formats: {supported}") + self.format = fmt + + # 3. 其他类型 + else: + supported = ", ".join(_EXT_TO_FORMAT) + raise TypeError( + f"Unsupported type: {type(data_or_path).__name__}. " + f"Please provide a file path ({supported}), bytes, or BytesIO." + ) + + self.buffer = BytesIO(raw) + self.caption = caption + + @classmethod + def column_type(cls) -> ColumnType: + return ColumnType.COLUMN_TYPE_VIDEO + + @classmethod + def build_data_record(cls, *, key: str, step: int, timestamp: Timestamp, data: List[VideoItem]) -> DataRecord: + return DataRecord( + key=key, step=step, timestamp=timestamp, type=cls.column_type(), videos=VideoValue(items=data) + ) + + def transform(self, *, step: int, path: Path) -> VideoItem: + content = self.buffer.getvalue() + sha256 = hashlib.sha256(content).hexdigest() + filename = f"{step:03d}-{sha256[:8]}.{self.format}" + safe_write(path / filename, content, mode="wb") + return VideoItem(filename=filename, sha256=sha256, size=len(content), caption=self.caption or "") diff --git a/swanlab/vendor/__init__.py b/swanlab/vendor/__init__.py index dc3c3e037..75284a95a 100644 --- a/swanlab/vendor/__init__.py +++ b/swanlab/vendor/__init__.py @@ -15,13 +15,17 @@ import boto3 import imageio import matplotlib + import matplotlib.figure import moviepy import numpy as np import pandas as pd import PIL + import PIL.Image import rdkit import soundfile import swanboard + import torch + import torchvision # 2. Expose the available modules for IDE auto-completion @@ -35,6 +39,8 @@ "soundfile", "swanboard", "boto3", + "torch", + "torchvision", # these are extra dependencies which are not in [project.optional-dependencies] "pd", ] @@ -50,6 +56,8 @@ "soundfile": "soundfile", "swanboard": "swanboard", "boto3": "boto3", + "torch": "torch", + "torchvision": "torchvision", # these are extra dependencies which are not in [project.optional-dependencies] "pd": "pandas", } @@ -71,8 +79,15 @@ "boto3": "s3", } +# 5. Submodule imports: some packages require submodules to be imported explicitly +# so their attributes are accessible (e.g. PIL.Image must be imported for PIL.Image to work) +_SUBMODULE_IMPORTS = { + "PIL": ["PIL.Image"], + "matplotlib": ["matplotlib.figure"], +} + -# 5. Module-level __getattr__ for lazy loading (PEP 562) +# 6. Module-level __getattr__ for lazy loading (PEP 562) def __getattr__(name: str) -> Any: if name in _LAZY_IMPORTS: module_path = _LAZY_IMPORTS[name] @@ -86,6 +101,10 @@ def __getattr__(name: str) -> Any: # Handle direct third-party library imports obj = importlib.import_module(module_path) + # Import required submodules so their attributes are accessible on the parent package + for submodule_path in _SUBMODULE_IMPORTS.get(name, []): + importlib.import_module(submodule_path) + # Cache the imported object in the module's global namespace globals()[name] = obj return obj diff --git a/tests/unit/sdk/cmd/finish/test_finish.py b/tests/unit/sdk/cmd/finish/test_finish.py deleted file mode 100644 index e101c57ed..000000000 --- a/tests/unit/sdk/cmd/finish/test_finish.py +++ /dev/null @@ -1,133 +0,0 @@ -""" -@author: cunyue -@file: test_finish.py -@time: 2026/3/14 -@description: 测试 swanlab.sdk.cmd.finish 中各函数的单元行为(均 mock 依赖,不启动真实 Run) -""" - -import sys -from unittest.mock import ANY, MagicMock - -from swanlab.sdk.cmd.finish import atexit_finish, finish, swanlab_excepthook - - -def _make_exc_info(exc: BaseException): - """辅助:构造 (tp, val, tb) 三元组,用于测试 excepthook""" - try: - raise exc - except BaseException: - tp, val, tb = sys.exc_info() - assert tp is not None and val is not None - return tp, val, tb - - -class TestFinishFunction: - def test_finish_calls_run_finish(self, monkeypatch): - """应将 state 和 error 透传给 run.finish()""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.helper.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.finish.get_run", lambda: mock_run) - - finish(state="crashed", error="something went wrong") - - mock_run.finish.assert_called_once_with("crashed", "something went wrong") - - def test_finish_default_state_is_success(self, monkeypatch): - """finish() 不传参时,默认 state 为 'success',error 为 None""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.helper.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.finish.get_run", lambda: mock_run) - - finish() - - mock_run.finish.assert_called_once_with("success", None) - - -class TestAtexitFinish: - def test_atexit_finish_no_run(self, monkeypatch): - """无活跃 Run 时,atexit_finish 直接返回,不调用 get_run""" - monkeypatch.setattr("swanlab.sdk.cmd.finish.has_run", lambda: False) - mock_get_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.finish.get_run", mock_get_run) - - atexit_finish() - - mock_get_run.assert_not_called() - - def test_atexit_finish_calls_run_finish(self, monkeypatch): - """有活跃 Run 时,atexit_finish 应调用 run.finish()""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.finish.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.finish.get_run", lambda: mock_run) - - atexit_finish() - - mock_run.finish.assert_called_once() - - -class TestSwanlabExcepthook: - def test_excepthook_keyboard_interrupt(self, monkeypatch): - """KeyboardInterrupt 异常 → run.finish(state='aborted', ...)""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.finish.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.finish.get_run", lambda: mock_run) - monkeypatch.setattr(sys, "__excepthook__", MagicMock()) - - tp, val, tb = _make_exc_info(KeyboardInterrupt()) - swanlab_excepthook(tp, val, tb) - - mock_run.finish.assert_called_once_with(state="aborted", error=ANY) - - def test_excepthook_generic_exception(self, monkeypatch): - """普通异常 → run.finish(state='crashed'),error 包含完整 traceback""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.finish.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.finish.get_run", lambda: mock_run) - monkeypatch.setattr(sys, "__excepthook__", MagicMock()) - - tp, val, tb = _make_exc_info(RuntimeError("boom")) - swanlab_excepthook(tp, val, tb) - - call_kwargs = mock_run.finish.call_args.kwargs - assert call_kwargs["state"] == "crashed" - assert "boom" in call_kwargs["error"] - - def test_excepthook_no_run(self, monkeypatch): - """无活跃 Run 时,不调用 run.finish,但不抛出异常""" - monkeypatch.setattr("swanlab.sdk.cmd.finish.has_run", lambda: False) - mock_get_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.finish.get_run", mock_get_run) - monkeypatch.setattr(sys, "__excepthook__", MagicMock()) - - tp, val, tb = _make_exc_info(RuntimeError("no run")) - swanlab_excepthook(tp, val, tb) - - mock_get_run.assert_not_called() - - def test_excepthook_always_calls_original_hook(self, monkeypatch): - """无论是否有活跃 Run,sys.__excepthook__ 必须被调用一次""" - monkeypatch.setattr("swanlab.sdk.cmd.finish.has_run", lambda: False) - mock_original = MagicMock() - monkeypatch.setattr(sys, "__excepthook__", mock_original) - - tp, val, tb = _make_exc_info(RuntimeError("test")) - swanlab_excepthook(tp, val, tb) - - mock_original.assert_called_once_with(tp, val, tb) - - def test_excepthook_internal_error_doesnt_crash(self, monkeypatch): - """excepthook 内部逻辑出错时,不应向上抛出异常,且仍调用 sys.__excepthook__""" - # 模拟 has_run 本身抛出异常(触发 except 块) - monkeypatch.setattr("swanlab.sdk.cmd.finish.has_run", MagicMock(side_effect=Exception("internal boom"))) - # mock console 以隔离 console.error 的实现细节 - monkeypatch.setattr("swanlab.sdk.cmd.finish.console", MagicMock()) - mock_original = MagicMock() - monkeypatch.setattr(sys, "__excepthook__", mock_original) - - tp, val, tb = _make_exc_info(RuntimeError("outer")) - - # 不应抛出 - swanlab_excepthook(tp, val, tb) - - # __excepthook__ 仍须被调用(finally 块保证) - mock_original.assert_called_once_with(tp, val, tb) diff --git a/tests/unit/sdk/cmd/finish/test_finish_e2e.py b/tests/unit/sdk/cmd/finish/test_finish_e2e.py index 12a4aea11..f916d80ff 100644 --- a/tests/unit/sdk/cmd/finish/test_finish_e2e.py +++ b/tests/unit/sdk/cmd/finish/test_finish_e2e.py @@ -9,8 +9,8 @@ import pytest import swanlab -from swanlab.sdk.cmd.finish import finish from swanlab.sdk.cmd.init import init +from swanlab.sdk.cmd.run import finish from swanlab.sdk.internal.run import has_run diff --git a/tests/unit/sdk/cmd/test_log.py b/tests/unit/sdk/cmd/test_log.py deleted file mode 100644 index f64a9d5a0..000000000 --- a/tests/unit/sdk/cmd/test_log.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -@author: cunyue -@file: test_log.py -@time: 2026/3/14 -@description: 测试 swanlab.sdk.cmd.log 中的函数 -""" - -from unittest.mock import MagicMock - -from swanlab.sdk.cmd.log import log, log_text - - -class TestLog: - def test_log_calls_run_log(self, monkeypatch): - """log() 应调用 run.log() 并传递参数""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.helper.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.log.get_run", lambda: mock_run) - - log({"loss": 0.5, "accuracy": 0.95}) - - mock_run.log.assert_called_once_with({"loss": 0.5, "accuracy": 0.95}, None) - - def test_log_with_step(self, monkeypatch): - """log() 应正确传递 step 参数""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.helper.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.log.get_run", lambda: mock_run) - - log({"loss": 0.5}, step=10) - - mock_run.log.assert_called_once_with({"loss": 0.5}, 10) - - -class TestLogText: - def test_log_text_calls_run_log_text(self, monkeypatch): - """log_text() 应调用 run.log_text() 并传递参数""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.helper.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.log.get_run", lambda: mock_run) - - log_text("output", "Training started") - - mock_run.log_text.assert_called_once_with("output", "Training started", None, None) - - def test_log_text_with_caption(self, monkeypatch): - """log_text() 应正确传递 caption 参数""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.helper.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.log.get_run", lambda: mock_run) - - log_text("prediction", "cat", caption="Model output") - - mock_run.log_text.assert_called_once_with("prediction", "cat", "Model output", None) - - def test_log_text_with_step(self, monkeypatch): - """log_text() 应正确传递 step 参数""" - mock_run = MagicMock() - monkeypatch.setattr("swanlab.sdk.cmd.helper.has_run", lambda: True) - monkeypatch.setattr("swanlab.sdk.cmd.log.get_run", lambda: mock_run) - - log_text("status", "Epoch complete", step=5) - - mock_run.log_text.assert_called_once_with("status", "Epoch complete", None, 5) diff --git a/tests/unit/sdk/internal/run/data/test_media_transform.py b/tests/unit/sdk/internal/run/data/test_media_transform.py index 3ff3d9b46..ab47e1d3e 100644 --- a/tests/unit/sdk/internal/run/data/test_media_transform.py +++ b/tests/unit/sdk/internal/run/data/test_media_transform.py @@ -18,12 +18,22 @@ import swanlab.sdk.internal.run.transforms # noqa: F401 # type: ignore — 触发所有子类注册 from swanlab.sdk.internal.context import TransformMedia from swanlab.sdk.internal.run.transforms.audio import Audio +from swanlab.sdk.internal.run.transforms.image import Image from swanlab.sdk.internal.run.transforms.text import Text +from swanlab.sdk.internal.run.transforms.video import Video + +# 最小合法 GIF89a(1×1 像素) +_GIF_1X1 = ( + b"GIF89a\x01\x00\x01\x00\x80\x00\x00\xff\xff\xff\x00\x00\x00" + b"!\xf9\x04\x00\x00\x00\x00\x00,\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;" +) # 注册表:TransformMedia 子类 → 无参工厂(每次调用返回内容相同的新实例) MEDIA_FACTORIES = { Audio: lambda: Audio(np.zeros((1, 4410), dtype=np.float32), sample_rate=44100), + Image: lambda: Image(np.zeros((10, 10, 3), dtype=np.uint8)), Text: lambda: Text(content="hello world"), + Video: lambda: Video(_GIF_1X1), } diff --git a/tests/unit/sdk/internal/run/data/test_video.py b/tests/unit/sdk/internal/run/data/test_video.py new file mode 100644 index 000000000..34cb81df8 --- /dev/null +++ b/tests/unit/sdk/internal/run/data/test_video.py @@ -0,0 +1,170 @@ +""" +@author: cunyue +@file: test_video.py +@time: 2026/3/15 +@description: 视频处理模块单元测试 +""" + +import hashlib +from io import BytesIO +from pathlib import Path + +import pytest +from google.protobuf.timestamp_pb2 import Timestamp + +from swanlab.proto.swanlab.metric.data.v1.media.video_pb2 import VideoItem +from swanlab.sdk.internal.run.transforms.video import Video + +# 最小合法 GIF89a(1×1 像素) +GIF_BYTES = ( + b"GIF89a\x01\x00\x01\x00\x80\x00\x00\xff\xff\xff\x00\x00\x00" + b"!\xf9\x04\x00\x00\x00\x00\x00,\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;" +) + + +@pytest.fixture +def gif_file(tmp_path: Path) -> str: + """写一个临时 GIF 文件,返回其路径字符串""" + path = tmp_path / "test.gif" + path.write_bytes(GIF_BYTES) + return str(path) + + +# ---------------------------------- 构造测试 ---------------------------------- + + +class TestVideoInit: + def test_from_bytes(self): + v = Video(GIF_BYTES) + assert v.format == "gif" + assert v.buffer.getvalue() == GIF_BYTES + assert v.caption is None + + def test_from_bytesio(self): + v = Video(BytesIO(GIF_BYTES)) + assert v.format == "gif" + assert v.buffer.getvalue() == GIF_BYTES + + def test_from_path(self, gif_file): + v = Video(gif_file) + assert v.format == "gif" + assert v.buffer.getvalue() == GIF_BYTES + + def test_caption_stored(self): + v = Video(GIF_BYTES, caption="my clip") + assert v.caption == "my clip" + + def test_caption_none_by_default(self): + v = Video(GIF_BYTES) + assert v.caption is None + + +# ---------------------------------- 错误处理测试 ---------------------------------- + + +class TestVideoInitErrors: + def test_invalid_extension_raises(self, tmp_path): + p = tmp_path / "clip.mp4" + p.write_bytes(b"\x00" * 16) + with pytest.raises(TypeError, match="Unsupported file extension"): + Video(str(p)) + + def test_nonexistent_path_raises(self): + with pytest.raises(ValueError, match="Failed to open file"): + Video("/nonexistent/path/clip.gif") + + def test_bad_magic_in_file_raises(self, tmp_path): + """文件扩展名合法但内容不是 GIF""" + p = tmp_path / "fake.gif" + p.write_bytes(b"\xff\xd8\xff\xe0" + b"\x00" * 16) # JPEG 魔数 + with pytest.raises(TypeError, match="magic number"): + Video(str(p)) + + def test_bad_magic_in_bytes_raises(self): + with pytest.raises(TypeError, match="Cannot detect video format"): + Video(b"\x00\x01\x02\x03\x04\x05\x06\x07") + + def test_unsupported_type_raises(self): + with pytest.raises(TypeError, match="Unsupported type"): + Video(12345) # type: ignore + + +# ---------------------------------- 套娃加载测试 ---------------------------------- + + +class TestVideoNesting: + def test_wrap_copies_buffer_and_format(self): + inner = Video(GIF_BYTES, caption="inner") + outer = Video(inner) + assert outer.buffer is inner.buffer + assert outer.format == inner.format + assert outer.caption == "inner" + + def test_outer_caption_overrides_inner(self): + inner = Video(GIF_BYTES, caption="inner") + outer = Video(inner, caption="outer") + assert outer.caption == "outer" + + def test_inner_caption_used_when_outer_none(self): + inner = Video(GIF_BYTES, caption="inner") + outer = Video(inner, caption=None) + assert outer.caption == "inner" + + +# ---------------------------------- column_type / build_data_record 测试 ---------------------------------- + + +class TestVideoColumnType: + def test_column_type(self): + from swanlab.proto.swanlab.metric.column.v1.column_pb2 import ColumnType + + assert Video.column_type() == ColumnType.COLUMN_TYPE_VIDEO + + +class TestVideoBuildDataRecord: + def test_build_data_record_structure(self, tmp_path): + item = Video(GIF_BYTES).transform(step=1, path=tmp_path) + ts = Timestamp() + record = Video.build_data_record(key="rollout", step=1, timestamp=ts, data=[item]) + + assert record.key == "rollout" + assert record.step == 1 + assert len(record.videos.items) == 1 + assert record.videos.items[0].filename == item.filename + + def test_build_data_record_multiple_items(self, tmp_path): + i1 = Video(GIF_BYTES).transform(step=1, path=tmp_path) + i2 = Video(BytesIO(GIF_BYTES)).transform(step=1, path=tmp_path) + ts = Timestamp() + record = Video.build_data_record(key="k", step=1, timestamp=ts, data=[i1, i2]) + assert len(record.videos.items) == 2 + + +# ---------------------------------- transform 特有字段测试 ---------------------------------- + + +class TestVideoTransform: + def test_transform_returns_video_item(self, tmp_path): + item = Video(GIF_BYTES).transform(step=1, path=tmp_path) + assert isinstance(item, VideoItem) + + def test_transform_sha256_correct(self, tmp_path): + item = Video(GIF_BYTES).transform(step=1, path=tmp_path) + content = (tmp_path / item.filename).read_bytes() + assert item.sha256 == hashlib.sha256(content).hexdigest() + + def test_transform_size_correct(self, tmp_path): + item = Video(GIF_BYTES).transform(step=1, path=tmp_path) + assert item.size == len((tmp_path / item.filename).read_bytes()) + + def test_transform_caption_empty_when_none(self, tmp_path): + item = Video(GIF_BYTES).transform(step=1, path=tmp_path) + assert item.caption == "" + + def test_transform_caption_preserved(self, tmp_path): + item = Video(GIF_BYTES, caption="hello").transform(step=1, path=tmp_path) + assert item.caption == "hello" + + def test_transform_format_in_filename(self, tmp_path): + item = Video(GIF_BYTES).transform(step=3, path=tmp_path) + assert item.filename.endswith(".gif") diff --git a/tests/unit/sdk/internal/run/test_finish_hook.py b/tests/unit/sdk/internal/run/test_finish_hook.py new file mode 100644 index 000000000..9be86d83f --- /dev/null +++ b/tests/unit/sdk/internal/run/test_finish_hook.py @@ -0,0 +1,90 @@ +""" +@author: cunyue +@file: test_finish_hook.py +@time: 2026/3/14 +@description: 测试 SwanLabRun._atexit_cleanup / _excepthook 的单元行为(均 mock 依赖,不启动真实 Run) +""" + +import sys +import threading +from unittest.mock import ANY, MagicMock, patch + +from swanlab.sdk.internal.run import SwanLabRun + + +def _make_exc_info(exc: BaseException): + """辅助:构造 (tp, val, tb) 三元组""" + # noinspection PyBroadException + try: + raise exc + except BaseException: + tp, val, tb = sys.exc_info() + assert tp is not None and val is not None + return tp, val, tb + + +def _make_mock_run(state: str = "running") -> MagicMock: + """构造一个最小化的 SwanLabRun 替身""" + mock = MagicMock(spec=SwanLabRun) + mock._state = state + mock._api_lock = threading.RLock() + return mock + + +class TestAtexitCleanup: + def test_no_op_when_not_running(self): + """_state != 'running' 时直接返回,不调用 finish""" + run = _make_mock_run(state="success") + SwanLabRun._atexit_cleanup(run) + run.finish.assert_not_called() + + def test_calls_finish_when_running(self): + """_state == 'running' 时应调用 finish()""" + run = _make_mock_run(state="running") + SwanLabRun._atexit_cleanup(run) + run.finish.assert_called_once() + + +class TestExcepthook: + def test_keyboard_interrupt_calls_aborted(self): + """KeyboardInterrupt → finish(state='aborted', ...)""" + run = _make_mock_run() + with patch("sys.__excepthook__"): + tp, val, tb = _make_exc_info(KeyboardInterrupt()) + SwanLabRun._excepthook(run, tp, val, tb) + run.finish.assert_called_once_with(state="aborted", error=ANY) + + def test_generic_exception_calls_crashed(self): + """普通异常 → finish(state='crashed'),error 包含完整 traceback""" + run = _make_mock_run() + with patch("sys.__excepthook__"): + tp, val, tb = _make_exc_info(RuntimeError("boom")) + SwanLabRun._excepthook(run, tp, val, tb) + call_kwargs = run.finish.call_args.kwargs + assert call_kwargs["state"] == "crashed" + assert "boom" in call_kwargs["error"] + + def test_no_op_when_not_running(self): + """_state != 'running' 时不调用 finish""" + run = _make_mock_run(state="success") + with patch("sys.__excepthook__"): + tp, val, tb = _make_exc_info(RuntimeError("no run")) + SwanLabRun._excepthook(run, tp, val, tb) + run.finish.assert_not_called() + + def test_always_calls_original_excepthook(self): + """无论是否有活跃 Run,sys.__excepthook__ 必须被调用一次""" + run = _make_mock_run(state="success") + with patch("sys.__excepthook__") as mock_original: + tp, val, tb = _make_exc_info(RuntimeError("test")) + SwanLabRun._excepthook(run, tp, val, tb) + mock_original.assert_called_once_with(tp, val, tb) + + def test_internal_error_doesnt_crash(self): + """excepthook 内部出错时不向上抛出,仍调用 sys.__excepthook__""" + run = _make_mock_run() + run.finish.side_effect = Exception("internal boom") + with patch("sys.__excepthook__") as mock_original: + tp, val, tb = _make_exc_info(RuntimeError("outer")) + SwanLabRun._excepthook(run, tp, val, tb) + mock_original.assert_called_once_with(tp, val, tb) diff --git a/tests/unit/sdk/internal/run/test_log_media.py b/tests/unit/sdk/internal/run/test_log_media.py new file mode 100644 index 000000000..b1a8451a8 --- /dev/null +++ b/tests/unit/sdk/internal/run/test_log_media.py @@ -0,0 +1,145 @@ +""" +@author: cunyue +@file: test_run_log_media.py +@time: 2026/3/15 +@description: SwanLabRun.log_text / log_image / log_audio / log_video 的参数化测试 +""" + +import threading +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from swanlab.sdk.internal.run import SwanLabRun +from swanlab.sdk.internal.run.transforms.audio import Audio +from swanlab.sdk.internal.run.transforms.image import Image +from swanlab.sdk.internal.run.transforms.text import Text +from swanlab.sdk.internal.run.transforms.video import Video + +# 最小合法 GIF89a(1×1 像素) +_GIF_1X1 = ( + b"GIF89a\x01\x00\x01\x00\x80\x00\x00\xff\xff\xff\x00\x00\x00" + b"!\xf9\x04\x00\x00\x00\x00\x00,\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;" +) + + +def _make_text() -> Text: + return Text(content="hello world") + + +def _make_audio() -> Audio: + return Audio(np.zeros((1, 4410), dtype=np.float32), sample_rate=44100) + + +def _make_image() -> Image: + return Image(np.zeros((10, 10, 3), dtype=np.uint8)) + + +def _make_video() -> Video: + return Video(_GIF_1X1) + + +# ────────────────────────────────────────────── +# Mock run +# ────────────────────────────────────────────── + + +class _MockRun: + """最小化的 SwanLabRun 替身,供非绑定方法测试使用""" + + def __init__(self): + self._state = "running" + self._api_lock = threading.RLock() + self.log = MagicMock() + + +# ────────────────────────────────────────────── +# 参数化用例 +# ────────────────────────────────────────────── + +# (method_name, media_factory, extra_kwargs_for_method) +_MEDIA_LOG_CASES = [ + pytest.param("log_text", _make_text, {}, id="log_text"), + pytest.param("log_image", _make_image, {}, id="log_image"), + pytest.param("log_audio", _make_audio, {"sample_rate": 44100}, id="log_audio"), + pytest.param("log_video", _make_video, {}, id="log_video"), +] + + +def _call_step(method_name, mock_run, key, data, extra, caption=None, step=None): + method = getattr(SwanLabRun, method_name) + return method(mock_run, key, data, **extra, caption=caption, step=step) + + +@pytest.mark.parametrize("method_name,factory,extra", _MEDIA_LOG_CASES) +class TestRunLogMedia: + """log_text / log_image / log_audio / log_video 的通用契约测试 + 没有log_echarts和log_object3D,因为他们的操作比较复杂 + """ + + def test_single_instance_wrapped_in_list(self, method_name, factory, extra): + """传入单个媒体实例 → log 被调用,data 会被包装""" + mock_run = _MockRun() + instance = factory() + _call_step(method_name, mock_run, "key", instance, extra) + mock_run.log.assert_called_once() + log_data = mock_run.log.call_args[0][0] + assert "key" in log_data + assert len(log_data["key"]) == 1 + assert log_data["key"] != [instance] + assert type(log_data["key"][0]) is instance.__class__ + + def test_list_of_instances_passed_through(self, method_name, factory, extra): + """传入实例列表 → log 被调用,data 保持列表不变""" + mock_run = _MockRun() + instances = [factory(), factory()] + _call_step(method_name, mock_run, "metrics/a", instances, extra) + mock_run.log.assert_called_once() + log_data = mock_run.log.call_args[0][0] + assert log_data["metrics/a"] != instances + assert all(isinstance(item, factory().__class__) for item in log_data["metrics/a"]) + + def test_step_forwarded(self, method_name, factory, extra): + """step 参数应被原样透传给 log""" + mock_run = _MockRun() + _call_step(method_name, mock_run, "k", factory(), extra, step=7) + _, kwargs = mock_run.log.call_args + assert kwargs.get("step") == 7 + + def test_step_none_by_default(self, method_name, factory, extra): + """不传 step 时,log 收到 step=None""" + mock_run = _MockRun() + _call_step(method_name, mock_run, "k", factory(), extra) + _, kwargs = mock_run.log.call_args + assert kwargs.get("step") is None + + def test_raises_when_not_running(self, method_name, factory, extra): + """run 未激活时应抛出 RuntimeError""" + mock_run = _MockRun() + mock_run._state = "finished" + with pytest.raises(RuntimeError, match="requires an active SwanLabRun"): + _call_step(method_name, mock_run, "k", factory(), extra) + + +def _call_caption(method_name, mock_run, key, data, extra, caption=None, step=None): + method = getattr(SwanLabRun, method_name) + return method(mock_run, key, data, **extra, caption=caption, step=step) + + +@pytest.mark.parametrize("method_name,factory,extra", _MEDIA_LOG_CASES) +class TestRunLogMediaCaption: + """caption 参数行为(对所有媒体类型一致)""" + + def test_caption_applied_to_raw_data(self, method_name, factory, extra): + """传入原始数据 + caption → normalize 后的对象 caption 正确""" + mock_run = _MockRun() + instance = factory() + _call_caption(method_name, mock_run, "k", instance, extra, caption="my caption") + # We can't inspect the caption easily without calling transform, so just + # verify log was called with a list + mock_run.log.assert_called_once() + log_data = mock_run.log.call_args[0][0] + assert isinstance(log_data["k"], list) + assert len(log_data["k"]) == 1 + assert log_data["k"][0].caption == "my caption" diff --git a/tests/unit/sdk/internal/run/test_decorators.py b/tests/unit/sdk/internal/run/test_run_decorators.py similarity index 100% rename from tests/unit/sdk/internal/run/test_decorators.py rename to tests/unit/sdk/internal/run/test_run_decorators.py