Skip to content

Commit a8a49cd

Browse files
authored
refactor(BA-4186): move out of convention function into valid directory (#8509)
1 parent 95230db commit a8a49cd

13 files changed

Lines changed: 157 additions & 58 deletions

File tree

changes/8509.enhance.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Resolve cyclic import error from AppProxy with moving non-relevant function out

src/ai/backend/appproxy/coordinator/api/types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ class StubResponseModel(BaseModel):
2525
success: Annotated[bool, Field(default=True)]
2626

2727

28+
class AppProxyStatusResponse(BaseModel):
29+
"""Response from AppProxy /status endpoint."""
30+
31+
api_version: str = Field(description="AppProxy API version (e.g., 'v1', 'v2')")
32+
advertise_address: str | None = Field(
33+
default=None, description="Advertised address for AppProxy"
34+
)
35+
36+
2837
class CircuitListResponseModel(BaseModel):
2938
circuits: list[SerializableCircuit]
3039

src/ai/backend/appproxy/coordinator/server.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
mime_match,
7272
ping_redis_connection,
7373
)
74+
from ai.backend.appproxy.coordinator.api.types import AppProxyStatusResponse
7475
from ai.backend.appproxy.coordinator.models.worker import WorkerStatus
7576
from ai.backend.common import redis_helper
7677
from ai.backend.common.clients.valkey_client.valkey_leader.client import ValkeyLeaderClient
@@ -845,10 +846,11 @@ async def status(request: web.Request) -> web.Response:
845846
root_ctx: RootContext = request.app["_root.context"]
846847
request["do_not_print_access_log"] = True
847848
advertised_addr = root_ctx.local_config.proxy_coordinator.advertise_base_url
848-
return web.json_response({
849-
"api_version": "v2",
850-
"advertise_address": advertised_addr,
851-
})
849+
response = AppProxyStatusResponse(
850+
api_version="v2",
851+
advertise_address=advertised_addr,
852+
)
853+
return web.json_response(response.model_dump(mode="json"))
852854

853855

854856
def handle_loop_error(

src/ai/backend/common/exception.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ class ErrorDomain(enum.StrEnum):
173173
ROLE = "role"
174174
METRIC = "metric"
175175
STORAGE_PROXY = "storage-proxy"
176+
APPPROXY = "appproxy"
176177
MESSAGE_QUEUE = "message-queue"
177178
NOTIFICATION = "notification"
178179
HEALTH_CHECK = "health-check"

src/ai/backend/manager/api/context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from ai.backend.common.types import ValkeyProfileTarget
4141
from ai.backend.manager.agent_cache import AgentRPCCache
4242
from ai.backend.manager.clients.agent import AgentClientPool
43+
from ai.backend.manager.clients.appproxy.client import AppProxyClientPool
4344
from ai.backend.manager.config.provider import ManagerConfigProvider
4445
from ai.backend.manager.idle import IdleCheckerHost
4546
from ai.backend.manager.models.storage import StorageSessionManager
@@ -97,6 +98,7 @@ class RootContext(BaseContext):
9798
registry: AgentRegistry
9899
agent_cache: AgentRPCCache
99100
agent_client_pool: AgentClientPool
101+
appproxy_client_pool: AppProxyClientPool
100102
sokovan_orchestrator: SokovanOrchestrator
101103
scheduling_controller: SchedulingController
102104
deployment_controller: DeploymentController

src/ai/backend/manager/api/scaling_group.py

Lines changed: 20 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,16 @@
1-
import json
21
import logging
32
from collections.abc import Iterable
43
from dataclasses import dataclass, field
54
from http import HTTPStatus
6-
from typing import TYPE_CHECKING, Any, cast
5+
from typing import TYPE_CHECKING, Any
76

8-
import aiohttp
97
import aiohttp_cors
10-
import aiotools
118
import trafaret as t
129
from aiohttp import web
1310

1411
from ai.backend.common import validators as tx
1512
from ai.backend.logging import BraceStyleAdapter
16-
from ai.backend.manager.errors.common import (
17-
InternalServerError,
18-
ObjectNotFound,
19-
ServerMisconfiguredError,
20-
)
13+
from ai.backend.manager.errors.common import ObjectNotFound
2114
from ai.backend.manager.models.scaling_group import query_allowed_sgroups
2215
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
2316

@@ -37,27 +30,6 @@ class WSProxyVersionQueryParams:
3730
db_ctx: ExtendedAsyncSAEngine = field(hash=False)
3831

3932

40-
@aiotools.lru_cache(expire_after=30) # expire after 30 seconds
41-
async def query_wsproxy_status(
42-
wsproxy_addr: str,
43-
) -> dict[str, Any]:
44-
async with (
45-
aiohttp.ClientSession() as session,
46-
session.get(
47-
wsproxy_addr + "/status",
48-
headers={"Accept": "application/json"},
49-
) as resp,
50-
):
51-
try:
52-
result = await resp.json()
53-
except (aiohttp.ContentTypeError, json.JSONDecodeError) as e:
54-
log.error("Failed to parse wsproxy status response from {}: {}", wsproxy_addr, e)
55-
raise InternalServerError(
56-
"Got invalid response from wsproxy when querying status"
57-
) from e
58-
return cast(dict[str, Any], result)
59-
60-
6133
@auth_required
6234
@server_status_required(READ_ALLOWED)
6335
@check_api_params(
@@ -100,29 +72,27 @@ async def get_wsproxy_version(request: web.Request, params: Any) -> web.Response
10072
domain_name = request["user"]["domain_name"]
10173
group_id_or_name = params["group"]
10274
log.info("SGROUPS.LIST(ak:{}, g:{}, d:{})", access_key, group_id_or_name, domain_name)
75+
# remove appproxy client pool from root_ctx when db query migrated to service layer.
10376
async with root_ctx.db.begin_readonly() as conn:
10477
sgroups = await query_allowed_sgroups(conn, domain_name, group_id_or_name or "", access_key)
105-
for sgroup in sgroups:
106-
if sgroup.name == scaling_group_name:
107-
wsproxy_addr = sgroup.wsproxy_addr
108-
if not wsproxy_addr:
109-
wsproxy_version = "v1"
110-
else:
111-
try:
112-
wsproxy_status = await query_wsproxy_status(wsproxy_addr)
113-
wsproxy_version = wsproxy_status["api_version"]
114-
except aiohttp.ClientConnectorError:
115-
log.error(
116-
"Failed to query the wsproxy {1} configured for sg:{0}",
117-
scaling_group_name,
118-
wsproxy_addr,
119-
)
120-
return ServerMisconfiguredError()
121-
return web.json_response({
122-
"wsproxy_version": wsproxy_version,
123-
})
124-
else:
78+
sgroup_filtered = [sg for sg in sgroups if sg.name == scaling_group_name]
79+
if not sgroup_filtered:
12580
raise ObjectNotFound(object_name="scaling group")
81+
sgroup = sgroup_filtered[0]
82+
83+
if not sgroup.wsproxy_addr:
84+
# if wsproxy_addr is not set, raise not found error(migrating from v1 behavior)
85+
# It should be either 404 or 500 before wsproxy_addr is mandatory field.
86+
raise ObjectNotFound(object_name="AppProxy address")
87+
client = root_ctx.appproxy_client_pool.load_client(
88+
sgroup.wsproxy_addr, sgroup.wsproxy_api_token or ""
89+
)
90+
status = await client.fetch_status()
91+
wsproxy_version = status.api_version
92+
93+
return web.json_response({
94+
"wsproxy_version": wsproxy_version,
95+
})
12696

12797

12898
async def init(app: web.Application) -> None:

src/ai/backend/manager/clients/appproxy/client.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,30 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import logging
15
from typing import Any
26
from uuid import UUID
37

48
import aiohttp
59

10+
from ai.backend.appproxy.coordinator.api.types import AppProxyStatusResponse
11+
from ai.backend.common.clients.http_client.client_pool import (
12+
ClientKey,
13+
ClientPool,
14+
tcp_client_session_factory,
15+
)
616
from ai.backend.common.exception import BackendAIError
717
from ai.backend.common.metrics.metric import DomainType, LayerType
818
from ai.backend.common.resilience.policies.metrics import MetricArgs, MetricPolicy
919
from ai.backend.common.resilience.policies.retry import BackoffStrategy, RetryArgs, RetryPolicy
1020
from ai.backend.common.resilience.resilience import Resilience
21+
from ai.backend.logging import BraceStyleAdapter
22+
from ai.backend.manager.errors.appproxy import AppProxyConnectionError, AppProxyResponseError
1123

1224
from .types import CreateEndpointRequestBody
1325

26+
log: BraceStyleAdapter = BraceStyleAdapter(logging.getLogger(__spec__.name))
27+
1428
appproxy_client_resilience = Resilience(
1529
policies=[
1630
MetricPolicy(MetricArgs(domain=DomainType.CLIENT, layer=LayerType.WSPROXY_CLIENT)),
@@ -26,6 +40,25 @@
2640
)
2741

2842

43+
class AppProxyClientPool:
44+
_client_pool: ClientPool
45+
46+
def __init__(self) -> None:
47+
self._client_pool = ClientPool(tcp_client_session_factory)
48+
49+
def load_client(self, address: str, token: str) -> AppProxyClient:
50+
client_session = self._client_pool.load_client_session(
51+
ClientKey(
52+
endpoint=address,
53+
domain="appproxy",
54+
)
55+
)
56+
return AppProxyClient(client_session, address, token)
57+
58+
async def close(self) -> None:
59+
await self._client_pool.close()
60+
61+
2962
class AppProxyClient:
3063
_client_session: aiohttp.ClientSession
3164
_address: str
@@ -36,6 +69,26 @@ def __init__(self, client_session: aiohttp.ClientSession, address: str, token: s
3669
self._address = address
3770
self._token = token
3871

72+
@appproxy_client_resilience.apply()
73+
async def fetch_status(self) -> AppProxyStatusResponse:
74+
try:
75+
async with self._client_session.get(
76+
"/status",
77+
headers={"Accept": "application/json"},
78+
) as resp:
79+
data = await resp.json()
80+
return AppProxyStatusResponse.model_validate(data)
81+
except aiohttp.ClientConnectorError as e:
82+
log.error("Failed to connect to app-proxy at {}: {}", self._address, e)
83+
raise AppProxyConnectionError(
84+
extra_msg=f"Failed to connect to AppProxy at {self._address}"
85+
) from e
86+
except (aiohttp.ContentTypeError, json.JSONDecodeError) as e:
87+
log.error("Failed to parse app-proxy status response from {}: {}", self._address, e)
88+
raise AppProxyResponseError(
89+
extra_msg=f"Invalid response from AppProxy at {self._address}"
90+
) from e
91+
3992
@appproxy_client_resilience.apply()
4093
async def create_endpoint(
4194
self,
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from __future__ import annotations
2+
3+
from aiohttp import web
4+
5+
from ai.backend.common.exception import (
6+
BackendAIError,
7+
ErrorCode,
8+
ErrorDetail,
9+
ErrorDomain,
10+
ErrorOperation,
11+
)
12+
13+
14+
class AppProxyConnectionError(BackendAIError, web.HTTPServiceUnavailable):
15+
"""Raised when connection to AppProxy fails."""
16+
17+
error_type = "https://api.backend.ai/probs/appproxy-connection-error"
18+
error_title = "Failed to connect to AppProxy."
19+
20+
def error_code(self) -> ErrorCode:
21+
return ErrorCode(
22+
domain=ErrorDomain.APPPROXY,
23+
operation=ErrorOperation.REQUEST,
24+
error_detail=ErrorDetail.UNAVAILABLE,
25+
)
26+
27+
28+
class AppProxyResponseError(BackendAIError, web.HTTPInternalServerError):
29+
"""Raised when AppProxy returns an invalid response."""
30+
31+
error_type = "https://api.backend.ai/probs/appproxy-response-error"
32+
error_title = "Invalid response from AppProxy."
33+
34+
def error_code(self) -> ErrorCode:
35+
return ErrorCode(
36+
domain=ErrorDomain.APPPROXY,
37+
operation=ErrorOperation.REQUEST,
38+
error_detail=ErrorDetail.INVALID_DATA_FORMAT,
39+
)

src/ai/backend/manager/registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ class AgentRegistry:
204204
pending_waits: set[asyncio.Task[None]]
205205
database_ptask_group: aiotools.PersistentTaskGroup
206206
webhook_ptask_group: aiotools.PersistentTaskGroup
207+
# TODO: Migrate to use root_ctx.appproxy_client_pool instead.
208+
# After migration, remove _client_pool and _load_app_proxy_client.
207209
_client_pool: ClientPool
208210
_agent_client_pool: AgentClientPool
209211

@@ -268,6 +270,7 @@ async def shutdown(self) -> None:
268270
await self.database_ptask_group.shutdown()
269271
await self.webhook_ptask_group.shutdown()
270272

273+
# TODO: Migrate callers to use root_ctx.appproxy_client_pool.load_client() instead.
271274
def _load_app_proxy_client(self, address: str, token: str) -> AppProxyClient:
272275
client_session = self._client_pool.load_client_session(
273276
ClientKey(

src/ai/backend/manager/server.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@
146146
from .api import ManagerStatus
147147
from .api.context import RootContext
148148
from .clients.agent import AgentClientPool, AgentPoolSpec
149+
from .clients.appproxy.client import AppProxyClientPool
149150
from .config.bootstrap import BootstrapConfig
150151
from .config.loader.config_overrider import ConfigOverrider
151152
from .config.loader.etcd_loader import (
@@ -762,6 +763,7 @@ async def processors_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
762763
event_producer=root_ctx.event_producer,
763764
agent_cache=root_ctx.agent_cache,
764765
notification_center=root_ctx.notification_center,
766+
appproxy_client_pool=root_ctx.appproxy_client_pool,
765767
),
766768
),
767769
[reporter_monitor, prometheus_monitor, audit_log_monitor],
@@ -1103,6 +1105,7 @@ async def agent_registry_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
11031105
recovery_timeout=60.0,
11041106
),
11051107
)
1108+
root_ctx.appproxy_client_pool = AppProxyClientPool()
11061109
root_ctx.registry = AgentRegistry(
11071110
root_ctx.config_provider,
11081111
root_ctx.db,
@@ -1126,6 +1129,7 @@ async def agent_registry_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
11261129
yield
11271130
finally:
11281131
await root_ctx.agent_client_pool.close()
1132+
await root_ctx.appproxy_client_pool.close()
11291133
await root_ctx.registry.shutdown()
11301134

11311135

0 commit comments

Comments
 (0)