Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions swanlab/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"merge_settings",
"init",
"log",
"save",
"register_callbacks",
"finish",
"Audio",
Expand Down
102 changes: 88 additions & 14 deletions swanlab/core_python/api/experiment/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@
@description: 定义实验相关的后端API接口
"""

from typing import Literal, Dict, TYPE_CHECKING, List, Union
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Union

from swanlab.core_python.api.type import RunType
from .utils import to_camel_case, parse_column_type

from .utils import (
parse_column_type,
to_camel_case,
unwrap_api_payload,
)

if TYPE_CHECKING:
from swanlab.core_python.client import Client
Expand All @@ -35,7 +40,7 @@ def update_experiment_state(
username: str,
projname: str,
cuid: str,
state: Literal['FINISHED', 'CRASHED', 'ABORTED'],
state: Literal["FINISHED", "CRASHED", "ABORTED"],
finished_at: str = None,
):
"""
Expand Down Expand Up @@ -76,7 +81,7 @@ def get_project_experiments(
- 'job_type': 按任务类型筛选,值为字符串
"""
# 特殊筛选条件配置:用户侧 key -> 后端 key 和操作符
SPECIAL_FILTER_CONFIG = {
special_filter_config = {
"group": {"key": "cluster", "op": "EQ"},
"tags": {"key": "labels", "op": "IN"},
"name": {"key": "name", "op": "EQ"},
Expand All @@ -88,33 +93,39 @@ def get_project_experiments(

if filters:
for key, value in filters.items():
if key in SPECIAL_FILTER_CONFIG:
if key in special_filter_config:
# 特殊字段处理
config = SPECIAL_FILTER_CONFIG[key]
config = special_filter_config[key]
# tags 需要转换为列表
filter_value = list(value) if key == "tags" and isinstance(value, (list, tuple)) else [value]
filter_value = (
list(value)
if key == "tags" and isinstance(value, (list, tuple))
else [value]
)
parsed_filters.append(
{
"key": config["key"],
"active": True,
"value": filter_value,
"op": config["op"],
"type": 'STABLE',
"type": "STABLE",
}
)
else:
# 常规字段处理
parsed_filters.append(
{
"key": to_camel_case(key) if parse_column_type(key) == 'STABLE' else key.split('.', 1)[-1],
"key": to_camel_case(key)
if parse_column_type(key) == "STABLE"
else key.split(".", 1)[-1],
"active": True,
"value": [value],
"op": 'EQ',
"op": "EQ",
"type": parse_column_type(key),
}
)

res = client.post(f"/project/{path}/runs/shows", data={'filters': parsed_filters})
res = client.post(f"/project/{path}/runs/shows", data={"filters": parsed_filters})
return res[0]


Expand All @@ -125,7 +136,7 @@ def get_single_experiment(client: "Client", *, path: str) -> RunType:
:param client: 已登录的客户端实例
:param path: 实验路径 username/project/expid
"""
proj_path, expid = path.rsplit('/', 1)
proj_path, expid = path.rsplit("/", 1)
res = client.get(f"/project/{proj_path}/runs/{expid}")
return res[0]

Expand All @@ -137,7 +148,7 @@ def get_experiment_metrics(client: "Client", *, expid: str, key: str) -> Dict[st
:param expid: 实验cuid
:param key: 指定字段列表
"""
res = client.get(f"/experiment/{expid}/column/csv", params={'key': key})
res = client.get(f"/experiment/{expid}/column/csv", params={"key": key})
return res[0]


Expand All @@ -147,15 +158,78 @@ def delete_experiment(client: "Client", *, path: str):
:param client: 已登录的客户端实例
:param path: 实验路径 'username/project/expid'
"""
proj_path, expid = path.rsplit('/', 1)
proj_path, expid = path.rsplit("/", 1)
client.delete(f"/project/{proj_path}/runs/{expid}")


def prepare_upload(
    client: "Client", exp_id: str, files: Iterable[Dict[str, object]]
) -> List[str]:
    """
    Create an upload task for regular files and return the presigned upload URLs.

    :param client: client instance used to issue the POST request
    :param exp_id: experiment id interpolated into the request path
    :param files: file descriptors sent as the ``files`` field of the request body
    :return: the ``urls`` list from the response; an empty list when ``files``
        is empty (no request is sent) or when the response carries no list
        under ``urls``
    """
    entries = [*files]
    if not entries:
        # Nothing to upload: skip the round trip entirely.
        return []

    response = client.post(f"/experiment/{exp_id}/files/prepare", {"files": entries})
    body, _ = response
    unwrapped = unwrap_api_payload(body)
    if not isinstance(unwrapped, dict):
        return []

    candidate = unwrapped.get("urls", [])
    if isinstance(candidate, list):
        return candidate
    return []


def complete_upload(
    client: "Client", exp_id: str, files: Iterable[Dict[str, object]]
) -> None:
    """
    Mark a regular file upload as complete.

    :param client: client instance used to issue the POST request
    :param exp_id: experiment id interpolated into the request path
    :param files: file descriptors sent as the ``files`` field of the request
        body; when empty, no request is sent
    """
    entries = [*files]
    if entries:
        client.post(f"/experiment/{exp_id}/files/complete", {"files": entries})


def prepare_multipart(
    client: "Client", exp_id: str, file: Dict[str, object]
) -> Dict[str, object]:
    """
    Create a multipart upload task and return its first file payload.

    Per the original description, the payload carries the uploadId and the
    list of part upload URLs.

    :param client: client instance used to issue the POST request
    :param exp_id: experiment id interpolated into the request path
    :param file: single file descriptor, sent wrapped in a one-element list
    :return: the first element of the ``files`` list in the response
    :raises ValueError: if the response has no non-empty ``files`` list
    """
    endpoint = f"/experiment/{exp_id}/files/prepare-multipart"
    body, _ = client.post(endpoint, {"files": [file]})
    unwrapped = unwrap_api_payload(body)

    entries = unwrapped.get("files", []) if isinstance(unwrapped, dict) else None
    if not entries or not isinstance(entries, list):
        raise ValueError("Multipart prepare API returned empty file payloads.")
    return entries[0]


def complete_multipart(client: "Client", exp_id: str, file: Dict[str, object]) -> None:
    """
    Mark a multipart upload as complete.

    :param client: client instance used to issue the POST request
    :param exp_id: experiment id interpolated into the request path
    :param file: single file descriptor, sent wrapped in a one-element list
    """
    endpoint = f"/experiment/{exp_id}/files/complete-multipart"
    request_body = {"files": [file]}
    client.post(endpoint, request_body)


# Names exported by a wildcard import of this module: the experiment query and
# state helpers plus the four file-upload helpers (regular and multipart).
__all__ = [
"send_experiment_heartbeat",
"update_experiment_state",
"get_project_experiments",
"get_single_experiment",
"get_experiment_metrics",
"delete_experiment",
"prepare_upload",
"complete_upload",
"prepare_multipart",
"complete_multipart",
]
63 changes: 56 additions & 7 deletions swanlab/core_python/api/experiment/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,69 @@
@description: 实验相关的后端API接口中的工具函数
"""

from typing import Dict, List, Optional, Tuple

from swanlab.core_python.api.type import ColumnType


# Derive the column type from the name's prefix
def parse_column_type(column: str) -> ColumnType:
    """
    Map a column name to its column type using the text before the first dot.

    :param column: column name, optionally prefixed (e.g. ``summary.loss``)
    :return: ``"SCALAR"`` for the ``summary`` prefix, ``"CONFIG"`` for the
        ``config`` prefix, and ``"STABLE"`` for anything else
    """
    prefix, _, _ = column.partition(".")
    if prefix == "config":
        return "CONFIG"
    if prefix == "summary":
        return "SCALAR"
    return "STABLE"


# Convert a snake_case name to camelCase
def to_camel_case(name: str) -> str:
    """
    Convert an underscore-separated name to camel case.

    The first segment is kept unchanged; every later segment is passed through
    ``str.capitalize`` (first character upper-cased, the rest lower-cased).

    :param name: underscore-separated name
    :return: the segments concatenated without underscores
    """
    head, *tail = name.split("_")
    return head + "".join(segment.capitalize() for segment in tail)


def unwrap_api_payload(data):
    """
    Strip a ``{"data": ...}`` envelope from an API response when present.

    :param data: decoded response body of any type
    :return: ``data["data"]`` when ``data`` is a dict whose ``"data"`` entry is
        a dict or list; otherwise ``data`` itself, unchanged
    """
    if not isinstance(data, dict):
        return data
    if "data" not in data:
        return data
    inner = data["data"]
    return inner if isinstance(inner, (dict, list)) else data


def extract_upload_id(payload: Dict[str, object]) -> Optional[str]:
    """
    Read the multipart upload id from a payload.

    :param payload: mapping expected to carry an ``uploadId`` entry
    :return: the ``uploadId`` value when it is a non-empty string, else ``None``
    """
    candidate = payload.get("uploadId")
    is_usable = isinstance(candidate, str) and bool(candidate)
    return candidate if is_usable else None



def extract_part_urls(payload: Dict[str, object]) -> List[Tuple[int, str]]:
    """
    Validate and collect the part upload URLs from a multipart payload.

    :param payload: mapping whose ``parts`` entry is a list of dicts, each
        with an integer ``partNumber`` and a non-empty string ``url``
    :return: ``(partNumber, url)`` tuples sorted by ascending part number
    :raises ValueError: if ``parts`` is not a list, an entry is not a dict, or
        an entry has an invalid ``partNumber`` or ``url``
    """
    parts = payload.get("parts")
    if not isinstance(parts, list):
        raise ValueError("Multipart upload URLs are missing in prepare response.")

    resolved: List[Tuple[int, str]] = []
    for part in parts:
        if not isinstance(part, dict):
            raise ValueError("Multipart prepare response contains invalid part data.")
        part_number = part.get("partNumber")
        url = part.get("url")
        # bool is a subclass of int, so a JSON true/false would otherwise pass
        # the integer check and be treated as part number 1/0.
        number_ok = isinstance(part_number, int) and not isinstance(part_number, bool)
        if not number_ok or not isinstance(url, str) or not url:
            raise ValueError("Invalid partNumber or url in multipart response.")
        resolved.append((part_number, url))

    return sorted(resolved, key=lambda item: item[0])


# Names exported by a wildcard import of this utilities module.
__all__ = [
"parse_column_type",
"to_camel_case",
"unwrap_api_payload",
"extract_upload_id",
"extract_part_urls",
]
19 changes: 12 additions & 7 deletions swanlab/core_python/api/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import time
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import List, Tuple
from typing import List, Optional, Tuple

import requests
from requests.exceptions import RequestException
Expand All @@ -18,7 +18,9 @@
from ...toolkit.models.data import MediaBuffer


def upload_file(*, url: str, buffer: BytesIO, max_retries=3):
# Default Content-Type sent by upload_file when the caller passes no mime_type.
MIME_TYPE_DEFAULT: str = "application/octet-stream"

def upload_file(*, url: str, buffer: BytesIO, max_retries=3, mime_type: str=MIME_TYPE_DEFAULT) -> Optional[str]:
"""
上传文件到COS
:param url: COS上传URL
Expand All @@ -33,13 +35,16 @@ def upload_file(*, url: str, buffer: BytesIO, max_retries=3):
response = session.put(
url,
data=buffer,
headers={'Content-Type': 'application/octet-stream'},
headers={"Content-Type": mime_type},
timeout=30,
)
response.raise_for_status()
return
etag = response.headers.get("ETag")
return etag if etag else None
except RequestException:
swanlog.warning("Upload attempt {} failed for URL: {}".format(attempt, url))
swanlog.warning(
"Upload attempt {} failed for URL: {}".format(attempt, url)
)
# 如果是最后一次尝试,抛出异常
if attempt == max_retries:
raise
Expand All @@ -57,10 +62,10 @@ def upload_to_cos(client: Client, *, cuid: str, buffers: List[MediaBuffer]):
failed_buffers: List[Tuple[str, MediaBuffer]] = []
# 1. 后端签名
data, _ = client.post(
'/resources/presigned/put',
"/resources/presigned/put",
{"experimentId": cuid, "paths": [buffer.file_name for buffer in buffers]},
)
urls: List[str] = data['urls']
urls: List[str] = data["urls"]
# 2. 并发上传
# executor.submit可能会失败,因为线程数有限或者线程池已经关闭
# 来自此issue: https://github.com/SwanHubX/SwanLab/issues/889,此时需要一个个发送
Expand Down
21 changes: 21 additions & 0 deletions swanlab/core_python/save/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from .manager import DirWatcher, FileUploadManager
from .model import SaveFileState, SaveFileModel
from .utils import (
collect_save_files,
compute_md5,
file_signature,
guess_mime_type,
validate_glob_path,
)

# Names re-exported from the manager, model and utils submodules of this package.
__all__ = [
"SaveFileState",
"SaveFileModel",
"collect_save_files",
"validate_glob_path",
"compute_md5",
"guess_mime_type",
"file_signature",
"FileUploadManager",
"DirWatcher",
]
Loading
Loading