Skip to content

Commit f660b19

Browse files
committed
feat: support base_url for Cohere rerankers
Add optional `base_url` to `CohereRerankerConf` and forward it when building `cohere.ClientV2`, allowing MemMachine to use self-hosted Cohere-compatible rerank APIs without adding a new provider. Also add unit tests for config parsing/validation and client kwargs forwarding, and update the three episodic sample configs with commented `base_url` examples. Reported-by: Kwangjin Ko <kwangjin.ko@sk.com> Signed-off-by: Hyeongtak Ji <hyeongtak.ji@sk.com>
1 parent bed055d commit f660b19

7 files changed

Lines changed: 130 additions & 3 deletions

File tree

packages/server/server_tests/memmachine_server/common/configuration/test_reranker_conf.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from memmachine_server.common.configuration.reranker_conf import (
88
AmazonBedrockRerankerConf,
99
BM25RerankerConf,
10+
CohereRerankerConf,
1011
CrossEncoderRerankerConf,
1112
EmbedderRerankerConf,
1213
IdentityRerankerConf,
@@ -47,6 +48,18 @@ def amazon_bedrock_reranker_conf() -> dict:
4748
}
4849

4950

51+
@pytest.fixture
52+
def cohere_reranker_conf() -> dict:
53+
return {
54+
"provider": "cohere",
55+
"config": {
56+
"cohere_key": "test-cohere-key",
57+
"model": "rerank-english-v3.0",
58+
"base_url": "http://localhost:8000",
59+
},
60+
}
61+
62+
5063
@pytest.fixture
5164
def cross_encoder_reranker_conf() -> dict:
5265
return {
@@ -83,6 +96,7 @@ def full_reranker_input(
8396
bm25_reranker_conf,
8497
identity_reranker_conf,
8598
amazon_bedrock_reranker_conf,
99+
cohere_reranker_conf,
86100
cross_encoder_reranker_conf,
87101
embedder_reranker_conf,
88102
rrf_hybrid_reranker_conf,
@@ -93,6 +107,7 @@ def full_reranker_input(
93107
"id_ranker_id": identity_reranker_conf,
94108
"bm_ranker_id": bm25_reranker_conf,
95109
"aws_reranker_id": amazon_bedrock_reranker_conf,
110+
"cohere_reranker_id": cohere_reranker_conf,
96111
"cross_encoder_id": cross_encoder_reranker_conf,
97112
"embedder_id": embedder_reranker_conf,
98113
},
@@ -121,6 +136,13 @@ def test_valid_amazon_bedrock_reranker_conf(amazon_bedrock_reranker_conf):
121136
assert conf.model_id == "amazon.rerank-v1:0"
122137

123138

139+
def test_valid_cohere_reranker_conf(cohere_reranker_conf):
140+
conf = CohereRerankerConf(**cohere_reranker_conf["config"])
141+
assert conf.cohere_key == SecretStr("test-cohere-key")
142+
assert conf.model == "rerank-english-v3.0"
143+
assert conf.base_url == "http://localhost:8000"
144+
145+
124146
def test_valid_cross_encoder_reranker_conf(cross_encoder_reranker_conf):
125147
conf = CrossEncoderRerankerConf(**cross_encoder_reranker_conf["config"])
126148
assert conf.model_name == "cross-encoder/qnli-electra-base"
@@ -150,6 +172,8 @@ def test_full_reranker_conf(full_reranker_input):
150172

151173
assert "aws_reranker_id" in conf.amazon_bedrock
152174
assert conf.amazon_bedrock["aws_reranker_id"].region == "us-west-2"
175+
assert "cohere_reranker_id" in conf.cohere
176+
assert conf.cohere["cohere_reranker_id"].base_url == "http://localhost:8000"
153177

154178
assert "cross_encoder_id" in conf.cross_encoder
155179
assert (
@@ -182,3 +206,12 @@ def test_missing_required_field_in_bedrock_reranker():
182206
AmazonBedrockRerankerConf(**config)
183207
assert "missing" in str(exc_info.value)
184208
assert "aws_secret_access_key" in str(exc_info.value)
209+
210+
211+
def test_invalid_cohere_base_url():
212+
with pytest.raises(ValidationError, match="Invalid base URL"):
213+
CohereRerankerConf(
214+
cohere_key=SecretStr("test-cohere-key"),
215+
model="rerank-english-v3.0",
216+
base_url="localhost:8000",
217+
)

packages/server/server_tests/memmachine_server/common/resource_manager/test_reranker_manager.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@ def mock_conf():
3434
bm25={"bm_ranker_id": BM25RerankerConf(tokenizer="simple")},
3535
cohere={
3636
"cohere_reranker_id": CohereRerankerConf(
37-
cohere_key=SecretStr("<COHERE_API_KEY>"),
37+
cohere_key=SecretStr("test-cohere-key"),
3838
model="rerank-english-v3.0",
39+
base_url="http://localhost:8000",
3940
),
4041
},
4142
cross_encoder={
@@ -132,6 +133,72 @@ async def test_build_cohere_rerankers(reranker_manager):
132133
assert reranker is not None
133134

134135

136+
@pytest.mark.asyncio
137+
async def test_build_cohere_reranker_passes_base_url(monkeypatch):
138+
captured_kwargs = {}
139+
140+
class FakeCohereClient:
141+
pass
142+
143+
def fake_client_v2(**kwargs):
144+
captured_kwargs.update(kwargs)
145+
return FakeCohereClient()
146+
147+
monkeypatch.setattr("cohere.ClientV2", fake_client_v2)
148+
149+
conf = RerankersConf(
150+
cohere={
151+
"cohere_reranker_id": CohereRerankerConf(
152+
cohere_key=SecretStr("test-cohere-key"),
153+
model="rerank-english-v3.0",
154+
base_url="http://localhost:8000",
155+
),
156+
},
157+
)
158+
reranker_manager = RerankerManager(
159+
conf=conf,
160+
embedder_factory=cast(EmbedderFactory, FakeEmbedderFactory()),
161+
)
162+
163+
await reranker_manager.get_reranker("cohere_reranker_id")
164+
165+
assert captured_kwargs == {
166+
"api_key": "test-cohere-key",
167+
"base_url": "http://localhost:8000",
168+
}
169+
170+
171+
@pytest.mark.asyncio
172+
async def test_build_cohere_reranker_omits_base_url_when_not_configured(monkeypatch):
173+
captured_kwargs = {}
174+
175+
class FakeCohereClient:
176+
pass
177+
178+
def fake_client_v2(**kwargs):
179+
captured_kwargs.update(kwargs)
180+
return FakeCohereClient()
181+
182+
monkeypatch.setattr("cohere.ClientV2", fake_client_v2)
183+
184+
conf = RerankersConf(
185+
cohere={
186+
"cohere_reranker_id": CohereRerankerConf(
187+
cohere_key=SecretStr("test-cohere-key"),
188+
model="rerank-english-v3.0",
189+
),
190+
},
191+
)
192+
reranker_manager = RerankerManager(
193+
conf=conf,
194+
embedder_factory=cast(EmbedderFactory, FakeEmbedderFactory()),
195+
)
196+
197+
await reranker_manager.get_reranker("cohere_reranker_id")
198+
199+
assert captured_kwargs == {"api_key": "test-cohere-key"}
200+
201+
135202
@requires_sentence_transformers
136203
@pytest.mark.asyncio
137204
async def test_build_cross_encoder_rerankers(reranker_manager):

packages/server/src/memmachine_server/common/configuration/reranker_conf.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Reranker configuration models."""
22

33
from typing import ClassVar, Self
4+
from urllib.parse import urlparse
45

56
import yaml
6-
from pydantic import BaseModel, Field, PrivateAttr, SecretStr
7+
from pydantic import BaseModel, Field, PrivateAttr, SecretStr, field_validator
78

89
from memmachine_server.common.configuration.mixin_confs import (
910
MetricsFactoryIdMixin,
@@ -64,6 +65,20 @@ class CohereRerankerConf(YamlSerializableMixin):
6465
default="rerank-english-v3.0",
6566
description="Cohere rerank model",
6667
)
68+
base_url: str | None = Field(
69+
default=None,
70+
description="Cohere-compatible Rerank API base URL",
71+
)
72+
73+
@field_validator("base_url")
74+
@classmethod
75+
def validate_base_url(cls, v: str | None) -> str | None:
76+
"""Ensure the base URL includes a scheme and host."""
77+
if v is not None:
78+
parsed_url = urlparse(v)
79+
if not parsed_url.scheme or not parsed_url.netloc:
80+
raise ValueError(f"Invalid base URL: base_url={v}")
81+
return v
6782

6883

6984
class CrossEncoderRerankerConf(YamlSerializableMixin):

packages/server/src/memmachine_server/common/resource_manager/reranker_manager.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,10 @@ async def _build_cohere_reranker(self, name: str) -> Reranker:
178178
conf = self.conf.cohere[name]
179179

180180
cohere_api_key = conf.cohere_key.get_secret_value() if conf.cohere_key else None
181-
client = ClientV2(api_key=cohere_api_key)
181+
if conf.base_url is not None:
182+
client = ClientV2(api_key=cohere_api_key, base_url=conf.base_url)
183+
else:
184+
client = ClientV2(api_key=cohere_api_key)
182185
params = CohereRerankerParams(
183186
client=client,
184187
model=conf.model,

sample_configs/episodic_memory_config.cpu.sample

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ resources:
138138
config:
139139
cohere_key: <COHERE_API_KEY>
140140
model: "rerank-english-v3.0"
141+
# Optional for self-hosted Cohere-compatible rerank APIs.
142+
# Examples:
143+
# base_url: "http://localhost:8000"
141144
aws_reranker_id:
142145
provider: "amazon-bedrock"
143146
config:

sample_configs/episodic_memory_config.gpu.sample

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ resources:
139139
config:
140140
cohere_key: <COHERE_API_KEY>
141141
model: "rerank-english-v3.0"
142+
# Optional for self-hosted Cohere-compatible rerank APIs.
143+
# Examples:
144+
# base_url: "http://localhost:8000"
142145
ce_ranker_id:
143146
provider: "cross-encoder"
144147
config:

sample_configs/episodic_memory_config.nebula.sample

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ resources:
163163
config:
164164
cohere_key: <COHERE_API_KEY>
165165
model: "rerank-english-v3.0"
166+
# Optional for self-hosted Cohere-compatible rerank APIs.
167+
# Examples:
168+
# base_url: "http://localhost:8000"
166169
aws_reranker_id:
167170
provider: "amazon-bedrock"
168171
config:

0 commit comments

Comments
 (0)