diff --git a/changes/8513.enhance.md b/changes/8513.enhance.md new file mode 100644 index 00000000000..fadd3ab1fe9 --- /dev/null +++ b/changes/8513.enhance.md @@ -0,0 +1 @@ +migrate unnecessary service related function from api to service \ No newline at end of file diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index cacb75c4ed3..7274ad774bb 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -118,7 +118,6 @@ from .utils import ( LegacyBaseRequestModel, LegacyBaseResponseModel, - Undefined, catch_unexpected, check_api_params, deprecated_stub, @@ -313,49 +312,6 @@ def check_and_return(self, value: Any) -> object: tx.AliasedKey(["attach_network", "attachNetwork"], default=None): t.Null | tx.UUID, }) -overwritten_param_check = t.Dict({ - t.Key("template_id"): tx.UUID, - t.Key("session_name"): tx.SessionName, - t.Key("image", default=None): t.Null | t.String, - tx.AliasedKey(["session_type", "sess_type"]): tx.Enum(SessionTypes), - t.Key("group", default=None): t.Null | t.String, - t.Key("domain", default=None): t.Null | t.String, - t.Key("config", default=None): t.Null | t.Mapping(t.String, t.Any), - t.Key("tag", default=None): t.Null | t.String, - t.Key("enqueue_only", default=False): t.ToBool, - t.Key("max_wait_seconds", default=0): t.Int[0:], - t.Key("reuse", default=True): t.ToBool, - t.Key("startup_command", default=None): t.Null | t.String, - t.Key("bootstrap_script", default=None): t.Null | t.String, - t.Key("owner_access_key", default=None): t.Null | t.String, - tx.AliasedKey(["scaling_group", "scalingGroup"], default=None): t.Null | t.String, - tx.AliasedKey(["cluster_size", "clusterSize"], default=None): t.Null | t.Int[1:], - tx.AliasedKey(["cluster_mode", "clusterMode"], default="SINGLE_NODE"): tx.Enum(ClusterMode), - tx.AliasedKey(["starts_at", "startsAt"], default=None): t.Null | t.String, - tx.AliasedKey(["batch_timeout", "batchTimeout"], default=None): t.Null | tx.TimeDuration, -}).allow_extra("*") - - -def sub(d: dict[Any, Any], old: Any, new: Any) -> dict[Any, Any]: - for k, v in d.items(): - if isinstance(v, (Mapping, dict)): - d[k] = sub(dict(v), old, new) - elif d[k] == old: - d[k] = new - return d - - -def drop_undefined(d: dict[Any, Any]) -> dict[Any, Any]: - newd: dict[Any, Any] = {} - for k, v in d.items(): - if isinstance(v, (Mapping, dict)): - newval = drop_undefined(dict(v)) - if len(newval.keys()) > 0: # exclude empty dict always - newd[k] = newval - elif not isinstance(v, Undefined): - newd[k] = v - return newd - async def query_userinfo( request: web.Request, diff --git a/src/ai/backend/manager/services/session/service.py b/src/ai/backend/manager/services/session/service.py index cc8cdcbe47f..6b9c531ee5f 100644 --- a/src/ai/backend/manager/services/session/service.py +++ b/src/ai/backend/manager/services/session/service.py @@ -36,10 +36,6 @@ ) from ai.backend.logging.utils import BraceStyleAdapter from ai.backend.manager.api.scaling_group import query_wsproxy_status -from ai.backend.manager.api.session import ( - drop_undefined, - overwritten_param_check, -) from ai.backend.manager.api.utils import undefined from ai.backend.manager.bgtask.tasks.commit_session import CommitSessionManifest from ai.backend.manager.bgtask.types import ManagerBgtaskName @@ -193,7 +189,12 @@ UploadFilesAction, UploadFilesActionResult, ) -from ai.backend.manager.services.session.types import CommitStatusInfo, LegacySessionInfo +from ai.backend.manager.services.session.types import ( + CommitStatusInfo, + LegacySessionInfo, + overwritten_param_check, +) +from ai.backend.manager.services.session.utils import drop_undefined from ai.backend.manager.sokovan.scheduling_controller import SchedulingController from ai.backend.manager.types import UserScope diff --git a/src/ai/backend/manager/services/session/types.py b/src/ai/backend/manager/services/session/types.py index b06d5bfbf8a..6b394d54c05 100644 --- a/src/ai/backend/manager/services/session/types.py +++ b/src/ai/backend/manager/services/session/types.py @@ -1,11 +1,39 @@ +from __future__ import annotations + import dataclasses from dataclasses import dataclass, field from datetime import datetime from typing import Any, Optional from uuid import UUID +import trafaret as t + +from ai.backend.common import validators as tx +from ai.backend.common.types import ClusterMode, SessionTypes from ai.backend.manager.data.session.types import SessionStatus +overwritten_param_check = t.Dict({ + t.Key("template_id"): tx.UUID, + t.Key("session_name"): tx.SessionName, + t.Key("image", default=None): t.Null | t.String, + tx.AliasedKey(["session_type", "sess_type"]): tx.Enum(SessionTypes), + t.Key("group", default=None): t.Null | t.String, + t.Key("domain", default=None): t.Null | t.String, + t.Key("config", default=None): t.Null | t.Mapping(t.String, t.Any), + t.Key("tag", default=None): t.Null | t.String, + t.Key("enqueue_only", default=False): t.ToBool, + t.Key("max_wait_seconds", default=0): t.Int[0:], + t.Key("reuse", default=True): t.ToBool, + t.Key("startup_command", default=None): t.Null | t.String, + t.Key("bootstrap_script", default=None): t.Null | t.String, + t.Key("owner_access_key", default=None): t.Null | t.String, + tx.AliasedKey(["scaling_group", "scalingGroup"], default=None): t.Null | t.String, + tx.AliasedKey(["cluster_size", "clusterSize"], default=None): t.Null | t.Int[1:], + tx.AliasedKey(["cluster_mode", "clusterMode"], default="SINGLE_NODE"): tx.Enum(ClusterMode), + tx.AliasedKey(["starts_at", "startsAt"], default=None): t.Null | t.String, + tx.AliasedKey(["batch_timeout", "batchTimeout"], default=None): t.Null | tx.TimeDuration, +}).allow_extra("*") + @dataclass class LegacySessionInfo: diff --git a/src/ai/backend/manager/services/session/utils.py b/src/ai/backend/manager/services/session/utils.py new file mode 100644 index 00000000000..c940ab3cf55 --- /dev/null +++ b/src/ai/backend/manager/services/session/utils.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from ai.backend.manager.api.utils import Undefined + + +def drop_undefined(d: dict[Any, Any]) -> dict[Any, Any]: + newd: dict[Any, Any] = {} + for k, v in d.items(): + if isinstance(v, (Mapping, dict)): + newval = drop_undefined(dict(v)) + if len(newval.keys()) > 0: # exclude empty dict always + newd[k] = newval + elif not isinstance(v, Undefined): + newd[k] = v + return newd