Skip to content

Commit 12938af

Browse files
committed
fix: address OCI streaming review feedback
1 parent c22dbfb commit 12938af

2 files changed

Lines changed: 159 additions & 17 deletions

File tree

src/cohere/oci_client.py

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ def map_request_to_oci(
384384
"OCI config profile is missing 'key_file'. "
385385
"Session-based auth requires a key_file entry in your OCI config profile."
386386
)
387+
key_file = os.path.expanduser(key_file)
387388
private_key = oci.signer.load_private_key_from_file(key_file)
388389

389390
signer = oci.auth.signers.SecurityTokenSigner(
@@ -411,7 +412,6 @@ def _event_hook(request: httpx.Request) -> None:
411412
url = get_oci_url(
412413
region=oci_region,
413414
endpoint=endpoint,
414-
stream="stream" in endpoint or body.get("stream", False),
415415
)
416416

417417
# Transform request body to OCI format
@@ -517,16 +517,13 @@ def _hook(response: httpx.Response) -> None:
517517
def get_oci_url(
518518
region: str,
519519
endpoint: str,
520-
stream: bool = False,
521520
) -> str:
522521
"""
523522
Map Cohere endpoints to OCI Generative AI endpoints.
524523
525524
Args:
526525
region: OCI region (e.g., "us-chicago-1")
527526
endpoint: Cohere endpoint name
528-
stream: Whether this is a streaming request
529-
530527
Returns:
531528
Full OCI Generative AI endpoint URL
532529
"""
@@ -972,6 +969,21 @@ def transform_oci_stream_wrapper(
972969

973970
generation_id = str(uuid.uuid4())
974971
emitted_start = False
972+
v1_text_parts: typing.List[str] = []
973+
v1_finish_reason = "COMPLETE"
974+
v1_response: typing.Dict[str, typing.Any] = {
975+
"text": "",
976+
"generation_id": generation_id,
977+
"response_id": None,
978+
"citations": [],
979+
"documents": [],
980+
"is_search_required": None,
981+
"search_queries": [],
982+
"search_results": [],
983+
"chat_history": [],
984+
"meta": {"api_version": {"version": "1"}},
985+
}
986+
v2_message_end_delta: typing.Dict[str, typing.Any] = {}
975987
buffer = b""
976988
for chunk in stream:
977989
buffer += chunk
@@ -982,10 +994,25 @@ def transform_oci_stream_wrapper(
982994
if line.startswith("data: "):
983995
data_str = line[6:] # Remove "data: " prefix
984996
if data_str.strip() == "[DONE]":
985-
# Emit message-end event for V2 before stopping
986997
if is_v2:
987-
message_end_event = {"type": "message-end"}
998+
message_end_event: typing.Dict[str, typing.Any] = {
999+
"type": "message-end",
1000+
"id": generation_id,
1001+
}
1002+
if v2_message_end_delta:
1003+
message_end_event["delta"] = v2_message_end_delta
9881004
yield b"data: " + json.dumps(message_end_event).encode("utf-8") + b"\n\n"
1005+
elif endpoint in ["chat_stream", "chat"]:
1006+
stream_end_event = {
1007+
"event_type": "stream-end",
1008+
"finish_reason": v1_finish_reason,
1009+
"response": {
1010+
**v1_response,
1011+
"text": "".join(v1_text_parts),
1012+
"finish_reason": v1_finish_reason,
1013+
},
1014+
}
1015+
yield json.dumps(stream_end_event).encode("utf-8") + b"\n"
9891016
# Return to stop the generator completely
9901017
return
9911018

@@ -1023,10 +1050,61 @@ def transform_oci_stream_wrapper(
10231050
yield b"data: " + json.dumps(content_start).encode("utf-8") + b"\n\n"
10241051
emitted_start = True
10251052

1026-
cohere_event = transform_stream_event(endpoint, oci_event, is_v2)
10271053
if is_v2:
1028-
yield b"data: " + json.dumps(cohere_event).encode("utf-8") + b"\n\n"
1054+
if endpoint in ["chat_stream", "chat"] and "finishReason" in oci_event:
1055+
content_event = transform_stream_event(
1056+
endpoint,
1057+
{k: v for k, v in oci_event.items() if k != "finishReason"},
1058+
is_v2,
1059+
)
1060+
content_payload = (
1061+
content_event.get("delta", {})
1062+
.get("message", {})
1063+
.get("content", {})
1064+
)
1065+
if content_payload.get("text") or content_payload.get("thinking"):
1066+
yield b"data: " + json.dumps(content_event).encode("utf-8") + b"\n\n"
1067+
1068+
usage_data = oci_event.get("usage", {})
1069+
usage: typing.Dict[str, typing.Any] = {
1070+
"tokens": {
1071+
"input_tokens": usage_data.get("inputTokens", 0),
1072+
"output_tokens": usage_data.get("completionTokens", 0),
1073+
}
1074+
}
1075+
if usage_data.get("inputTokens") or usage_data.get("completionTokens"):
1076+
usage["billed_units"] = {
1077+
"input_tokens": usage_data.get("inputTokens", 0),
1078+
"output_tokens": usage_data.get("completionTokens", 0),
1079+
}
1080+
v2_message_end_delta = {
1081+
"finish_reason": oci_event.get("finishReason"),
1082+
"usage": usage,
1083+
}
1084+
1085+
content_end_event = {"type": "content-end", "index": 0}
1086+
yield b"data: " + json.dumps(content_end_event).encode("utf-8") + b"\n\n"
1087+
else:
1088+
cohere_event = transform_stream_event(endpoint, oci_event, is_v2)
1089+
yield b"data: " + json.dumps(cohere_event).encode("utf-8") + b"\n\n"
10291090
else:
1091+
if endpoint in ["chat_stream", "chat"]:
1092+
text = oci_event.get("text", "")
1093+
if text:
1094+
v1_text_parts.append(text)
1095+
if "finishReason" in oci_event:
1096+
v1_finish_reason = oci_event.get("finishReason", v1_finish_reason)
1097+
if "chatHistory" in oci_event:
1098+
v1_response["chat_history"] = oci_event.get("chatHistory", [])
1099+
if "citations" in oci_event:
1100+
v1_response["citations"] = oci_event.get("citations", [])
1101+
if "documents" in oci_event:
1102+
v1_response["documents"] = oci_event.get("documents", [])
1103+
if "searchQueries" in oci_event:
1104+
v1_response["search_queries"] = oci_event.get("searchQueries", [])
1105+
if "searchResults" in oci_event:
1106+
v1_response["search_results"] = oci_event.get("searchResults", [])
1107+
cohere_event = transform_stream_event(endpoint, oci_event, is_v2)
10301108
yield json.dumps(cohere_event).encode("utf-8") + b"\n"
10311109
except Exception as e:
10321110
raise RuntimeError(

tests/test_oci_client.py

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import unittest
1616

1717
import cohere
18+
from cohere.errors import NotFoundError
1819

1920

2021
@unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set")
@@ -467,11 +468,14 @@ def test_embed_english_v3(self):
467468

468469
def test_embed_light_v3(self):
469470
"""Test embed-english-light-v3.0 model."""
470-
response = self.client.embed(
471-
model="embed-english-light-v3.0",
472-
texts=["Test"],
473-
input_type="search_document",
474-
)
471+
try:
472+
response = self.client.embed(
473+
model="embed-english-light-v3.0",
474+
texts=["Test"],
475+
input_type="search_document",
476+
)
477+
except NotFoundError:
478+
self.skipTest("embed-english-light-v3.0 is not available in this OCI region/profile")
475479
self.assertIsNotNone(response.embeddings)
476480
self.assertEqual(len(response.embeddings[0]), 384)
477481

@@ -709,7 +713,7 @@ def test_get_oci_url_known_endpoints(self):
709713
url = get_oci_url("us-chicago-1", "chat")
710714
self.assertIn("/actions/chat", url)
711715

712-
url = get_oci_url("us-chicago-1", "chat_stream", stream=True)
716+
url = get_oci_url("us-chicago-1", "chat_stream")
713717
self.assertIn("/actions/chat", url)
714718

715719
def test_get_oci_url_unknown_endpoint_raises(self):
@@ -744,7 +748,7 @@ def test_stream_wrapper_emits_full_event_lifecycle(self):
744748

745749
chunks = [
746750
b'data: {"message": {"content": [{"type": "TEXT", "text": "Hello"}]}}\n',
747-
b'data: {"message": {"content": [{"type": "TEXT", "text": " world"}]}, "finishReason": "COMPLETE"}\n',
751+
b'data: {"message": {"content": [{"type": "TEXT", "text": " world"}]}, "finishReason": "COMPLETE", "usage": {"inputTokens": 3, "completionTokens": 2}}\n',
748752
b"data: [DONE]\n",
749753
]
750754

@@ -758,13 +762,19 @@ def test_stream_wrapper_emits_full_event_lifecycle(self):
758762
self.assertEqual(event_types[0], "message-start")
759763
self.assertEqual(event_types[1], "content-start")
760764
self.assertEqual(event_types[2], "content-delta")
761-
self.assertEqual(event_types[3], "content-end")
762-
self.assertEqual(event_types[4], "message-end")
765+
self.assertEqual(event_types[3], "content-delta")
766+
self.assertEqual(event_types[4], "content-end")
767+
self.assertEqual(event_types[5], "message-end")
763768

764769
self.assertIn("id", events[0])
765770
self.assertEqual(events[0]["delta"]["message"]["role"], "assistant")
766771
self.assertEqual(events[1]["index"], 0)
767772
self.assertEqual(events[1]["delta"]["message"]["content"]["type"], "text")
773+
self.assertEqual(events[2]["delta"]["message"]["content"]["text"], "Hello")
774+
self.assertEqual(events[3]["delta"]["message"]["content"]["text"], " world")
775+
self.assertEqual(events[5]["delta"]["finish_reason"], "COMPLETE")
776+
self.assertEqual(events[5]["delta"]["usage"]["tokens"]["input_tokens"], 3)
777+
self.assertEqual(events[5]["delta"]["usage"]["tokens"]["output_tokens"], 2)
768778

769779
def test_stream_wrapper_skips_malformed_json_with_warning(self):
770780
"""Test that malformed JSON in SSE streams is skipped with a warning."""
@@ -779,6 +789,60 @@ def test_stream_wrapper_skips_malformed_json_with_warning(self):
779789
events = list(transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True))
780790
self.assertEqual(len(events), 4)
781791

792+
def test_v1_stream_wrapper_emits_stream_end(self):
793+
"""Test that V1 chat streams end with a stream-end event containing the response payload."""
794+
import json
795+
from cohere.oci_client import transform_oci_stream_wrapper
796+
797+
chunks = [
798+
b'data: {"text": "Hello", "isFinished": false}\n',
799+
b'data: {"text": " world", "isFinished": true, "finishReason": "COMPLETE"}\n',
800+
b"data: [DONE]\n",
801+
]
802+
803+
events = [
804+
json.loads(raw.decode("utf-8"))
805+
for raw in transform_oci_stream_wrapper(iter(chunks), "chat_stream", is_v2=False)
806+
]
807+
808+
self.assertEqual(events[0]["event_type"], "text-generation")
809+
self.assertEqual(events[1]["event_type"], "text-generation")
810+
self.assertEqual(events[2]["event_type"], "stream-end")
811+
self.assertEqual(events[2]["finish_reason"], "COMPLETE")
812+
self.assertEqual(events[2]["response"]["text"], "Hello world")
813+
814+
def test_session_auth_expands_key_file_path(self):
815+
"""Test that session-based auth expands key_file paths before loading them."""
816+
from unittest.mock import MagicMock, patch
817+
from cohere.oci_client import map_request_to_oci
818+
819+
mock_private_key = object()
820+
mock_security_token_signer = MagicMock()
821+
mock_oci = MagicMock()
822+
mock_oci.signer.load_private_key_from_file.return_value = mock_private_key
823+
mock_oci.auth.signers.SecurityTokenSigner.return_value = mock_security_token_signer
824+
825+
oci_config = {
826+
"security_token_file": "~/.oci/sessions/TEST/token",
827+
"key_file": "~/.oci/sessions/TEST/oci_api_key.pem",
828+
}
829+
830+
with patch("cohere.oci_client.lazy_oci", return_value=mock_oci), patch(
831+
"builtins.open", create=True
832+
) as mock_open:
833+
mock_open.return_value.__enter__.return_value.read.return_value = "token"
834+
with patch("os.path.expanduser", side_effect=lambda p: p.replace("~", "/Users/test")):
835+
with patch("cohere.oci_client.transform_request_to_oci", return_value={"compartmentId": "c"}):
836+
hook = map_request_to_oci(oci_config, "us-chicago-1", "compartment-123", is_v2_client=False)
837+
request = MagicMock()
838+
request.url.path = "/v1/chat"
839+
request.read.return_value = b'{"message":"hi"}'
840+
request.method = "POST"
841+
request.extensions = {}
842+
hook(request)
843+
844+
mock_oci.signer.load_private_key_from_file.assert_called_with("/Users/test/.oci/sessions/TEST/oci_api_key.pem")
845+
782846
def test_stream_wrapper_raises_on_transform_error(self):
783847
"""Test that stream transformation errors are re-raised with OCI-specific context."""
784848
from cohere.oci_client import transform_oci_stream_wrapper

0 commit comments

Comments (0)