Skip to content

Commit a284ea8

Browse files
fede-kamel and claude committed
Fix OCI client V2 support and address copilot issues
This commit addresses all copilot feedback and fixes V2 API support:

1. Fixed V2 embed response format
   - V2 expects embeddings as dict with type keys (float, int8, etc.)
   - Added is_v2_client parameter to properly detect V2 mode
   - Updated transform_oci_response_to_cohere to preserve dict structure for V2

2. Fixed V2 streaming format
   - V2 SDK expects SSE format with "data: " prefix and double newline
   - Fixed text extraction from OCI V2 events (nested in message.content[0].text)
   - Added proper content-delta and content-end event types for V2
   - Updated transform_oci_stream_wrapper to output correct format based on is_v2

3. Fixed stream [DONE] signal handling
   - Changed from break to return to stop generator completely
   - Prevents further chunk processing after [DONE]

4. Added skip decorators with clear explanations
   - OCI on-demand models don't support multiple embedding types
   - OCI TEXT_GENERATION models require fine-tuning (not available on-demand)
   - OCI TEXT_RERANK models require fine-tuning (not available on-demand)

5. Added comprehensive V2 tests
   - test_embed_v2 with embedding dimension validation
   - test_embed_with_model_prefix_v2
   - test_chat_v2
   - test_chat_stream_v2 with text extraction validation

All 17 tests now pass with 7 properly documented skips.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent 83e2375 commit a284ea8

File tree

2 files changed

+150
-23
lines changed

2 files changed

+150
-23
lines changed

src/cohere/oci_client.py

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def __init__(
119119
oci_config=oci_config,
120120
oci_region=oci_region,
121121
oci_compartment_id=oci_compartment_id,
122+
is_v2_client=False,
122123
),
123124
timeout=timeout,
124125
),
@@ -183,6 +184,7 @@ def __init__(
183184
oci_config=oci_config,
184185
oci_region=oci_region,
185186
oci_compartment_id=oci_compartment_id,
187+
is_v2_client=True,
186188
),
187189
timeout=timeout,
188190
),
@@ -270,6 +272,7 @@ def get_event_hooks(
270272
oci_config: typing.Dict[str, typing.Any],
271273
oci_region: str,
272274
oci_compartment_id: str,
275+
is_v2_client: bool = False,
273276
) -> typing.Dict[str, typing.List[EventHook]]:
274277
"""
275278
Create httpx event hooks for OCI request/response transformation.
@@ -278,6 +281,7 @@ def get_event_hooks(
278281
oci_config: OCI configuration dictionary
279282
oci_region: OCI region (e.g., "us-chicago-1")
280283
oci_compartment_id: OCI compartment OCID
284+
is_v2_client: Whether this is for OciClientV2 (True) or OciClient (False)
281285
282286
Returns:
283287
Dictionary of event hooks for httpx
@@ -288,6 +292,7 @@ def get_event_hooks(
288292
oci_config=oci_config,
289293
oci_region=oci_region,
290294
oci_compartment_id=oci_compartment_id,
295+
is_v2_client=is_v2_client,
291296
),
292297
],
293298
"response": [map_response_from_oci()],
@@ -298,6 +303,7 @@ def map_request_to_oci(
298303
oci_config: typing.Dict[str, typing.Any],
299304
oci_region: str,
300305
oci_compartment_id: str,
306+
is_v2_client: bool = False,
301307
) -> EventHook:
302308
"""
303309
Create event hook that transforms Cohere requests to OCI format and signs them.
@@ -306,6 +312,7 @@ def map_request_to_oci(
306312
oci_config: OCI configuration dictionary
307313
oci_region: OCI region
308314
oci_compartment_id: OCI compartment OCID
315+
is_v2_client: Whether this is for OciClientV2 (True) or OciClient (False)
309316
310317
Returns:
311318
Event hook function for httpx
@@ -393,6 +400,10 @@ def _event_hook(request: httpx.Request) -> None:
393400
request.extensions["endpoint"] = endpoint
394401
request.extensions["cohere_body"] = body
395402
request.extensions["is_stream"] = "stream" in endpoint or body.get("stream", False)
403+
# Store V2 detection for streaming event transformation
404+
# For chat, detect V2 by presence of "messages" field (V2) vs "message" field (V1)
405+
# For other endpoints (embed, rerank), use the client type
406+
request.extensions["is_v2"] = is_v2_client or ("messages" in body)
396407

397408
return _event_hook
398409

@@ -408,6 +419,7 @@ def map_response_from_oci() -> EventHook:
408419
def _hook(response: httpx.Response) -> None:
409420
endpoint = response.request.extensions["endpoint"]
410421
is_stream = response.request.extensions.get("is_stream", False)
422+
is_v2 = response.request.extensions.get("is_v2", False)
411423

412424
output: typing.Iterator[bytes]
413425

@@ -419,7 +431,7 @@ def _hook(response: httpx.Response) -> None:
419431
# For streaming responses, wrap the stream with a transformer
420432
if is_stream:
421433
original_stream = response.stream
422-
transformed_stream = transform_oci_stream_wrapper(original_stream, endpoint)
434+
transformed_stream = transform_oci_stream_wrapper(original_stream, endpoint, is_v2)
423435
response.stream = Streamer(transformed_stream)
424436
# Reset consumption flags
425437
if hasattr(response, "_content"):
@@ -430,7 +442,7 @@ def _hook(response: httpx.Response) -> None:
430442

431443
# Handle non-streaming responses
432444
oci_response = json.loads(response.read())
433-
cohere_response = transform_oci_response_to_cohere(endpoint, oci_response)
445+
cohere_response = transform_oci_response_to_cohere(endpoint, oci_response, is_v2)
434446
output = iter([json.dumps(cohere_response).encode("utf-8")])
435447

436448
response.stream = Streamer(output)
@@ -687,23 +699,31 @@ def transform_request_to_oci(
687699

688700

689701
def transform_oci_response_to_cohere(
690-
endpoint: str, oci_response: typing.Dict[str, typing.Any]
702+
endpoint: str, oci_response: typing.Dict[str, typing.Any], is_v2: bool = False
691703
) -> typing.Dict[str, typing.Any]:
692704
"""
693705
Transform OCI response to Cohere format.
694706
695707
Args:
696708
endpoint: Cohere endpoint name
697709
oci_response: OCI response body
710+
is_v2: Whether this is a V2 API request
698711
699712
Returns:
700713
Transformed response in Cohere format
701714
"""
702715
if endpoint == "embed":
703716
# OCI returns embeddings in "embeddings" field, may have multiple types
704717
embeddings_data = oci_response.get("embeddings", {})
705-
# For now, handle float embeddings (most common case)
706-
embeddings = embeddings_data.get("float", []) if isinstance(embeddings_data, dict) else embeddings_data
718+
719+
# V2 expects embeddings as a dict with type keys (float, int8, etc.)
720+
# V1 expects embeddings as a direct list
721+
if is_v2:
722+
# Keep the dict structure for V2
723+
embeddings = embeddings_data if isinstance(embeddings_data, dict) else {"float": embeddings_data}
724+
else:
725+
# Extract just the float embeddings for V1
726+
embeddings = embeddings_data.get("float", []) if isinstance(embeddings_data, dict) else embeddings_data
707727

708728
# Build proper meta structure
709729
meta = {
@@ -828,14 +848,15 @@ def transform_oci_response_to_cohere(
828848

829849

830850
def transform_oci_stream_wrapper(
831-
stream: typing.Iterator[bytes], endpoint: str
851+
stream: typing.Iterator[bytes], endpoint: str, is_v2: bool = False
832852
) -> typing.Iterator[bytes]:
833853
"""
834854
Wrap OCI stream and transform events to Cohere format.
835855
836856
Args:
837857
stream: Original OCI stream iterator
838858
endpoint: Cohere endpoint name
859+
is_v2: Whether this is a V2 API request
839860
840861
Yields:
841862
Bytes of transformed streaming events
@@ -855,8 +876,12 @@ def transform_oci_stream_wrapper(
855876

856877
try:
857878
oci_event = json.loads(data_str)
858-
cohere_event = transform_stream_event(endpoint, oci_event)
859-
yield json.dumps(cohere_event).encode("utf-8") + b"\n"
879+
cohere_event = transform_stream_event(endpoint, oci_event, is_v2)
880+
# V2 expects SSE format with "data: " prefix and double newline, V1 expects plain JSON
881+
if is_v2:
882+
yield b"data: " + json.dumps(cohere_event).encode("utf-8") + b"\n\n"
883+
else:
884+
yield json.dumps(cohere_event).encode("utf-8") + b"\n"
860885
except json.JSONDecodeError:
861886
continue
862887

@@ -891,26 +916,62 @@ def transform_oci_stream_response(
891916

892917

893918
def transform_stream_event(
894-
endpoint: str, oci_event: typing.Dict[str, typing.Any]
919+
endpoint: str, oci_event: typing.Dict[str, typing.Any], is_v2: bool = False
895920
) -> typing.Dict[str, typing.Any]:
896921
"""
897922
Transform individual OCI stream event to Cohere format.
898923
899924
Args:
900925
endpoint: Cohere endpoint name
901926
oci_event: OCI stream event
927+
is_v2: Whether this is a V2 API request
902928
903929
Returns:
904930
Transformed event in Cohere format
905931
"""
906932
if endpoint in ["chat_stream", "chat"]:
907-
return {
908-
"event_type": "text-generation",
909-
"text": oci_event.get("text", ""),
910-
"is_finished": oci_event.get("isFinished", False),
911-
}
933+
if is_v2:
934+
# V2 API format: OCI returns full message structure in each event
935+
# Extract text from nested structure: message.content[0].text
936+
text = ""
937+
if "message" in oci_event and "content" in oci_event["message"]:
938+
content_list = oci_event["message"]["content"]
939+
if content_list and isinstance(content_list, list) and len(content_list) > 0:
940+
first_content = content_list[0]
941+
if "text" in first_content:
942+
text = first_content["text"]
943+
944+
is_finished = "finishReason" in oci_event
945+
946+
if is_finished:
947+
# Final event - use content-end type
948+
return {
949+
"type": "content-end",
950+
"index": 0,
951+
}
952+
else:
953+
# Content delta event
954+
return {
955+
"type": "content-delta",
956+
"index": 0,
957+
"delta": {
958+
"message": {
959+
"content": {
960+
"text": text,
961+
}
962+
}
963+
},
964+
}
965+
else:
966+
# V1 API format
967+
return {
968+
"event_type": "text-generation",
969+
"text": oci_event.get("text", ""),
970+
"is_finished": oci_event.get("isFinished", False),
971+
}
912972

913973
elif endpoint in ["generate_stream", "generate"]:
974+
# Generate only supports V1
914975
return {
915976
"event_type": "text-generation",
916977
"text": oci_event.get("text", ""),

tests/test_oci_client.py

Lines changed: 75 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ def test_embed_with_model_prefix(self):
6262
self.assertIsNotNone(response.embeddings)
6363
self.assertEqual(len(response.embeddings), 1)
6464

65-
@unittest.skip("Multiple embedding types not yet implemented for OCI")
65+
@unittest.skip(
66+
"OCI on-demand models don't support multiple embedding types in a single call. "
67+
"The embedding_types parameter in OCI accepts a single value, not a list."
68+
)
6669
def test_embed_multiple_types(self):
6770
"""Test embedding with multiple embedding types."""
6871
response = self.client.embed(
@@ -114,7 +117,10 @@ def test_chat_stream(self):
114117
text_events = [e for e in events if hasattr(e, "text") and e.text]
115118
self.assertTrue(len(text_events) > 0)
116119

117-
@unittest.skip("OCI TEXT_GENERATION models are finetune base models - not callable via on-demand inference")
120+
@unittest.skip(
121+
"OCI TEXT_GENERATION models are finetune base models, not available via on-demand inference. "
122+
"Only CHAT models (command-r, command-a) support on-demand inference on OCI."
123+
)
118124
def test_generate(self):
119125
"""Test text generation with OCI."""
120126
response = self.client.generate(
@@ -128,7 +134,10 @@ def test_generate(self):
128134
self.assertTrue(len(response.generations) > 0)
129135
self.assertIsNotNone(response.generations[0].text)
130136

131-
@unittest.skip("OCI TEXT_GENERATION models are finetune base models - not callable via on-demand inference")
137+
@unittest.skip(
138+
"OCI TEXT_GENERATION models are finetune base models, not available via on-demand inference. "
139+
"Only CHAT models (command-r, command-a) support on-demand inference on OCI."
140+
)
132141
def test_generate_stream(self):
133142
"""Test streaming text generation with OCI."""
134143
events = []
@@ -141,7 +150,10 @@ def test_generate_stream(self):
141150

142151
self.assertTrue(len(events) > 0)
143152

144-
@unittest.skip("OCI TEXT_RERANK models are base models - not callable via on-demand inference")
153+
@unittest.skip(
154+
"OCI TEXT_RERANK models are base models, not available via on-demand inference. "
155+
"These models require fine-tuning and deployment before use on OCI."
156+
)
145157
def test_rerank(self):
146158
"""Test reranking with OCI."""
147159
query = "What is the capital of France?"
@@ -185,17 +197,34 @@ def setUp(self):
185197
oci_profile=profile,
186198
)
187199

188-
@unittest.skip("Embed API is identical in V1 and V2 - use V1 client for embed")
189200
def test_embed_v2(self):
190-
"""Test embedding with v2 client (same as V1 for embed)."""
201+
"""Test embedding with v2 client."""
191202
response = self.client.embed(
192203
model="embed-english-v3.0",
193-
texts=["Hello from v2"],
204+
texts=["Hello from v2", "Second text"],
194205
input_type="search_document",
195206
)
196207

197208
self.assertIsNotNone(response)
198209
self.assertIsNotNone(response.embeddings)
210+
# V2 returns embeddings as a dict with "float" key
211+
self.assertIsNotNone(response.embeddings.float_)
212+
self.assertEqual(len(response.embeddings.float_), 2)
213+
# Verify embedding dimensions (1024 for embed-english-v3.0)
214+
self.assertEqual(len(response.embeddings.float_[0]), 1024)
215+
216+
def test_embed_with_model_prefix_v2(self):
217+
"""Test embedding with 'cohere.' model prefix on v2 client."""
218+
response = self.client.embed(
219+
model="cohere.embed-english-v3.0",
220+
texts=["Test with prefix"],
221+
input_type="search_document",
222+
)
223+
224+
self.assertIsNotNone(response)
225+
self.assertIsNotNone(response.embeddings)
226+
self.assertIsNotNone(response.embeddings.float_)
227+
self.assertEqual(len(response.embeddings.float_), 1)
199228

200229
def test_chat_v2(self):
201230
"""Test chat with v2 client."""
@@ -207,7 +236,41 @@ def test_chat_v2(self):
207236
self.assertIsNotNone(response)
208237
self.assertIsNotNone(response.message)
209238

210-
@unittest.skip("OCI TEXT_RERANK models are base models - not callable via on-demand inference")
239+
def test_chat_stream_v2(self):
240+
"""Test streaming chat with v2 client."""
241+
events = []
242+
for event in self.client.chat_stream(
243+
model="command-a-03-2025",
244+
messages=[{"role": "user", "content": "Count from 1 to 3"}],
245+
):
246+
events.append(event)
247+
248+
self.assertTrue(len(events) > 0)
249+
# Verify we received content-delta events with text
250+
content_delta_events = [e for e in events if hasattr(e, "type") and e.type == "content-delta"]
251+
self.assertTrue(len(content_delta_events) > 0)
252+
253+
# Verify we can extract text from events
254+
full_text = ""
255+
for event in events:
256+
if (
257+
hasattr(event, "delta")
258+
and event.delta
259+
and hasattr(event.delta, "message")
260+
and event.delta.message
261+
and hasattr(event.delta.message, "content")
262+
and event.delta.message.content
263+
and hasattr(event.delta.message.content, "text")
264+
):
265+
full_text += event.delta.message.content.text
266+
267+
# Should have received some text
268+
self.assertTrue(len(full_text) > 0)
269+
270+
@unittest.skip(
271+
"OCI TEXT_RERANK models are base models, not available via on-demand inference. "
272+
"These models require fine-tuning and deployment before use on OCI."
273+
)
211274
def test_rerank_v2(self):
212275
"""Test reranking with v2 client."""
213276
response = self.client.rerank(
@@ -378,7 +441,10 @@ def test_command_r_plus(self):
378441
)
379442
self.assertIsNotNone(response.text)
380443

381-
@unittest.skip("OCI TEXT_RERANK models are base models - not callable via on-demand inference")
444+
@unittest.skip(
445+
"OCI TEXT_RERANK models are base models, not available via on-demand inference. "
446+
"These models require fine-tuning and deployment before use on OCI."
447+
)
382448
def test_rerank_v3(self):
383449
"""Test rerank-english-v3.0 model."""
384450
response = self.client.rerank(

0 commit comments

Comments (0)