Skip to content

Commit cc43b23

Browse files
authored
Merge pull request #18 from Serverless-Devs/fix-integration-specific-model
feat(model): support specific model parameter in CommonModel
2 parents f377b0d + cd79e94 commit cc43b23

File tree

3 files changed

+17
-19
lines changed

3 files changed

+17
-19
lines changed

agentrun/integration/builtin/model.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -69,29 +69,25 @@ def model(
6969
backend_type = kwargs.get("backend_type")
7070
model = kwargs.get("model")
7171

72+
if isinstance(input, str):
73+
from agentrun.model.client import ModelClient
74+
75+
client = ModelClient(config=config)
76+
input = client.get(name=input, backend_type=backend_type, config=config)
77+
7278
if isinstance(input, ModelProxy):
7379
return CommonModel(
74-
model=model or "",
7580
model_obj=input,
7681
backend_type=BackendType.PROXY,
82+
specific_model=model,
7783
config=config,
7884
)
7985
elif isinstance(input, ModelService):
8086
return CommonModel(
81-
model=model or "",
8287
model_obj=input,
8388
backend_type=BackendType.SERVICE,
89+
specific_model=model,
8490
config=config,
8591
)
86-
87-
from agentrun.model.client import ModelClient
88-
89-
client = ModelClient(config=config)
90-
model_obj = client.get(name=input, backend_type=backend_type, config=config)
91-
92-
return CommonModel(
93-
model=input,
94-
model_obj=model_obj,
95-
backend_type=backend_type,
96-
config=config,
97-
)
92+
else:
93+
raise TypeError("input must be str, ModelProxy or ModelService")

agentrun/integration/utils/model.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ class CommonModel:
1919

2020
def __init__(
2121
self,
22-
model: Optional[str],
2322
model_obj: Union[ModelService, ModelProxy],
2423
backend_type: Optional[BackendType] = None,
24+
specific_model: Optional[str] = None,
2525
config: Optional[Config] = None,
2626
):
27-
self.model = model
2827
self.model_obj = model_obj
2928
self.backend_type = backend_type
29+
self.specific_model = specific_model
3030
self.config = config or Config()
3131

3232
def completions(self, *args, **kwargs):
@@ -40,7 +40,10 @@ def responses(self, *args, **kwargs):
4040
def get_model_info(self, config: Optional[Config] = None):
4141
"""获取模型信息"""
4242
cfg = Config.with_configs(self.config, config)
43-
return self.model_obj.model_info(config=cfg)
43+
info = self.model_obj.model_info(config=cfg)
44+
if self.specific_model:
45+
info.model = self.specific_model
46+
return info
4447

4548
def __convert_model(self, adapter_name: str):
4649
try:

tests/unittests/integration/test_integration.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def extract_payload(request):
8686
def build_response(request, route):
8787
payload = extract_payload(request)
8888
is_stream = payload.get("stream", False)
89+
assert payload.get("model") == "mock-model-proxy"
8990
response_json = self._build_response(
9091
payload.get("messages") or [], payload.get("tools")
9192
)
@@ -355,8 +356,6 @@ def get_mocked_model(
355356
)
356357
m = model("fake-llm-model")
357358

358-
assert m.model == "fake-llm-model"
359-
360359
return m
361360

362361
def test_langchain(self, monkeypatch, mock_llm_transport):

0 commit comments

Comments
 (0)