1- import json
21import logging
32from collections .abc import Iterable
43from dataclasses import dataclass , field
54from http import HTTPStatus
6- from typing import TYPE_CHECKING , Any , cast
5+ from typing import TYPE_CHECKING , Any
76
8- import aiohttp
97import aiohttp_cors
10- import aiotools
118import trafaret as t
129from aiohttp import web
1310
1411from ai .backend .common import validators as tx
1512from ai .backend .logging import BraceStyleAdapter
16- from ai .backend .manager .errors .common import (
17- InternalServerError ,
18- ObjectNotFound ,
19- ServerMisconfiguredError ,
20- )
13+ from ai .backend .manager .errors .common import ObjectNotFound
2114from ai .backend .manager .models .scaling_group import query_allowed_sgroups
2215from ai .backend .manager .models .utils import ExtendedAsyncSAEngine
2316
@@ -37,27 +30,6 @@ class WSProxyVersionQueryParams:
3730 db_ctx : ExtendedAsyncSAEngine = field (hash = False )
3831
3932
40- @aiotools .lru_cache (expire_after = 30 ) # expire after 30 seconds
41- async def query_wsproxy_status (
42- wsproxy_addr : str ,
43- ) -> dict [str , Any ]:
44- async with (
45- aiohttp .ClientSession () as session ,
46- session .get (
47- wsproxy_addr + "/status" ,
48- headers = {"Accept" : "application/json" },
49- ) as resp ,
50- ):
51- try :
52- result = await resp .json ()
53- except (aiohttp .ContentTypeError , json .JSONDecodeError ) as e :
54- log .error ("Failed to parse wsproxy status response from {}: {}" , wsproxy_addr , e )
55- raise InternalServerError (
56- "Got invalid response from wsproxy when querying status"
57- ) from e
58- return cast (dict [str , Any ], result )
59-
60-
6133@auth_required
6234@server_status_required (READ_ALLOWED )
6335@check_api_params (
@@ -100,29 +72,27 @@ async def get_wsproxy_version(request: web.Request, params: Any) -> web.Response
10072 domain_name = request ["user" ]["domain_name" ]
10173 group_id_or_name = params ["group" ]
10274 log .info ("SGROUPS.LIST(ak:{}, g:{}, d:{})" , access_key , group_id_or_name , domain_name )
75+ # remove appproxy client pool from root_ctx when db query migrated to service layer.
10376 async with root_ctx .db .begin_readonly () as conn :
10477 sgroups = await query_allowed_sgroups (conn , domain_name , group_id_or_name or "" , access_key )
105- for sgroup in sgroups :
106- if sgroup .name == scaling_group_name :
107- wsproxy_addr = sgroup .wsproxy_addr
108- if not wsproxy_addr :
109- wsproxy_version = "v1"
110- else :
111- try :
112- wsproxy_status = await query_wsproxy_status (wsproxy_addr )
113- wsproxy_version = wsproxy_status ["api_version" ]
114- except aiohttp .ClientConnectorError :
115- log .error (
116- "Failed to query the wsproxy {1} configured for sg:{0}" ,
117- scaling_group_name ,
118- wsproxy_addr ,
119- )
120- return ServerMisconfiguredError ()
121- return web .json_response ({
122- "wsproxy_version" : wsproxy_version ,
123- })
124- else :
78+ sgroup_filtered = [sg for sg in sgroups if sg .name == scaling_group_name ]
79+ if not sgroup_filtered :
12580 raise ObjectNotFound (object_name = "scaling group" )
81+ sgroup = sgroup_filtered [0 ]
82+
83+ if not sgroup .wsproxy_addr :
84+ # if wsproxy_addr is not set, raise not found error(migrating from v1 behavior)
85+ # It should be either 404 or 500 before wsproxy_addr is mandatory field.
86+ raise ObjectNotFound (object_name = "AppProxy address" )
87+ client = root_ctx .appproxy_client_pool .load_client (
88+ sgroup .wsproxy_addr , sgroup .wsproxy_api_token or ""
89+ )
90+ status = await client .fetch_status ()
91+ wsproxy_version = status .api_version
92+
93+ return web .json_response ({
94+ "wsproxy_version" : wsproxy_version ,
95+ })
12696
12797
12898async def init (app : web .Application ) -> None :
0 commit comments