Skip to content

Commit 1b658bd

Browse files
committed
fix(oci): close remaining review gaps
1 parent 687ef1e commit 1b658bd

2 files changed

Lines changed: 127 additions & 11 deletions

File tree

src/cohere/oci_client.py

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -771,10 +771,34 @@ def transform_request_to_oci(
771771
chat_request["temperature"] = cohere_body["temperature"]
772772
if "max_tokens" in cohere_body:
773773
chat_request["maxTokens"] = cohere_body["max_tokens"]
774+
if "k" in cohere_body:
775+
chat_request["topK"] = cohere_body["k"]
776+
if "p" in cohere_body:
777+
chat_request["topP"] = cohere_body["p"]
778+
if "seed" in cohere_body:
779+
chat_request["seed"] = cohere_body["seed"]
780+
if "stop_sequences" in cohere_body:
781+
chat_request["stopSequences"] = cohere_body["stop_sequences"]
782+
if "frequency_penalty" in cohere_body:
783+
chat_request["frequencyPenalty"] = cohere_body["frequency_penalty"]
784+
if "presence_penalty" in cohere_body:
785+
chat_request["presencePenalty"] = cohere_body["presence_penalty"]
774786
if "preamble" in cohere_body:
775787
chat_request["preambleOverride"] = cohere_body["preamble"]
776788
if "chat_history" in cohere_body:
777789
chat_request["chatHistory"] = cohere_body["chat_history"]
790+
if "documents" in cohere_body:
791+
chat_request["documents"] = cohere_body["documents"]
792+
if "tools" in cohere_body:
793+
chat_request["tools"] = cohere_body["tools"]
794+
if "tool_results" in cohere_body:
795+
chat_request["toolResults"] = cohere_body["tool_results"]
796+
if "response_format" in cohere_body:
797+
chat_request["responseFormat"] = cohere_body["response_format"]
798+
if "safety_mode" in cohere_body:
799+
chat_request["safetyMode"] = cohere_body["safety_mode"]
800+
if "priority" in cohere_body:
801+
chat_request["priority"] = cohere_body["priority"]
778802

779803
# Handle streaming for both versions
780804
if "stream" in endpoint or cohere_body.get("stream"):
@@ -988,6 +1012,8 @@ def transform_oci_stream_wrapper(
9881012
generation_id = str(uuid.uuid4())
9891013
emitted_start = False
9901014
emitted_content_end = False
1015+
current_content_type: typing.Optional[str] = None
1016+
current_content_index = 0
9911017
final_finish_reason = "COMPLETE"
9921018
final_usage: typing.Optional[typing.Dict[str, typing.Any]] = None
9931019
full_v1_text = ""
@@ -1000,18 +1026,23 @@ def _emit_v2_event(event: typing.Dict[str, typing.Any]) -> bytes:
10001026
def _emit_v1_event(event: typing.Dict[str, typing.Any]) -> bytes:
10011027
return json.dumps(event).encode("utf-8") + b"\n"
10021028

1029+
def _current_v2_content_type(oci_event: typing.Dict[str, typing.Any]) -> str:
1030+
message = oci_event.get("message")
1031+
if isinstance(message, dict):
1032+
content_list = message.get("content")
1033+
if content_list and isinstance(content_list, list) and len(content_list) > 0:
1034+
oci_type = content_list[0].get("type", "TEXT").upper()
1035+
if oci_type == "THINKING":
1036+
return "thinking"
1037+
return "text"
1038+
10031039
def _transform_v2_event(oci_event: typing.Dict[str, typing.Any]) -> typing.Iterator[bytes]:
1004-
nonlocal emitted_start, emitted_content_end, final_finish_reason, final_usage
1040+
nonlocal emitted_start, emitted_content_end, current_content_type, current_content_index
1041+
nonlocal final_finish_reason, final_usage
1042+
1043+
event_content_type = _current_v2_content_type(oci_event)
10051044

10061045
if not emitted_start:
1007-
content_type = "text"
1008-
message = oci_event.get("message")
1009-
if isinstance(message, dict):
1010-
content_list = message.get("content")
1011-
if content_list and isinstance(content_list, list) and len(content_list) > 0:
1012-
oci_type = content_list[0].get("type", "TEXT").upper()
1013-
if oci_type == "THINKING":
1014-
content_type = "thinking"
10151046

10161047
yield _emit_v2_event(
10171048
{
@@ -1023,15 +1054,30 @@ def _transform_v2_event(oci_event: typing.Dict[str, typing.Any]) -> typing.Itera
10231054
yield _emit_v2_event(
10241055
{
10251056
"type": "content-start",
1026-
"index": 0,
1027-
"delta": {"message": {"content": {"type": content_type}}},
1057+
"index": current_content_index,
1058+
"delta": {"message": {"content": {"type": event_content_type}}},
10281059
}
10291060
)
10301061
emitted_start = True
1062+
current_content_type = event_content_type
1063+
elif current_content_type != event_content_type:
1064+
yield _emit_v2_event({"type": "content-end", "index": current_content_index})
1065+
current_content_index += 1
1066+
yield _emit_v2_event(
1067+
{
1068+
"type": "content-start",
1069+
"index": current_content_index,
1070+
"delta": {"message": {"content": {"type": event_content_type}}},
1071+
}
1072+
)
1073+
current_content_type = event_content_type
1074+
emitted_content_end = False
10311075

10321076
for cohere_event in typing.cast(
10331077
typing.List[typing.Dict[str, typing.Any]], transform_stream_event(endpoint, oci_event, is_v2=True)
10341078
):
1079+
if "index" in cohere_event:
1080+
cohere_event = {**cohere_event, "index": current_content_index}
10351081
if cohere_event["type"] == "content-end":
10361082
emitted_content_end = True
10371083
final_finish_reason = oci_event.get("finishReason", final_finish_reason)
@@ -1069,6 +1115,7 @@ def _process_line(line: str) -> typing.Iterator[bytes]:
10691115
yield _emit_v1_event(
10701116
{
10711117
"event_type": "stream-end",
1118+
"finish_reason": final_v1_finish_reason,
10721119
"response": {
10731120
"text": full_v1_text,
10741121
"generation_id": generation_id,

tests/test_oci_client.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -815,6 +815,46 @@ def test_transform_chat_request_optional_params(self):
815815
self.assertEqual(chat_req["toolChoice"], "REQUIRED")
816816
self.assertEqual(chat_req["priority"], 7)
817817

818+
def test_transform_v1_chat_request_optional_params(self):
819+
"""Test V1 chat request forwards the supported optional params."""
820+
from cohere.oci_client import transform_request_to_oci
821+
822+
body = {
823+
"model": "command-r-08-2024",
824+
"message": "Hi",
825+
"max_tokens": 100,
826+
"temperature": 0.7,
827+
"k": 10,
828+
"p": 0.8,
829+
"seed": 123,
830+
"stop_sequences": ["END"],
831+
"frequency_penalty": 0.5,
832+
"presence_penalty": 0.2,
833+
"documents": [{"title": "Doc", "text": "Body"}],
834+
"tools": [{"name": "lookup"}],
835+
"tool_results": [{"call": {"name": "lookup"}}],
836+
"response_format": {"type": "json_object"},
837+
"safety_mode": "NONE",
838+
"priority": 4,
839+
}
840+
result = transform_request_to_oci("chat", body, "compartment-123")
841+
842+
chat_req = result["chatRequest"]
843+
self.assertEqual(chat_req["maxTokens"], 100)
844+
self.assertEqual(chat_req["temperature"], 0.7)
845+
self.assertEqual(chat_req["topK"], 10)
846+
self.assertEqual(chat_req["topP"], 0.8)
847+
self.assertEqual(chat_req["seed"], 123)
848+
self.assertEqual(chat_req["stopSequences"], ["END"])
849+
self.assertEqual(chat_req["frequencyPenalty"], 0.5)
850+
self.assertEqual(chat_req["presencePenalty"], 0.2)
851+
self.assertEqual(chat_req["documents"], [{"title": "Doc", "text": "Body"}])
852+
self.assertEqual(chat_req["tools"], [{"name": "lookup"}])
853+
self.assertEqual(chat_req["toolResults"], [{"call": {"name": "lookup"}}])
854+
self.assertEqual(chat_req["responseFormat"], {"type": "json_object"})
855+
self.assertEqual(chat_req["safetyMode"], "NONE")
856+
self.assertEqual(chat_req["priority"], 4)
857+
818858
def test_transform_chat_request_tool_message_fields(self):
819859
"""Test tool message fields are converted to OCI names."""
820860
from cohere.oci_client import transform_request_to_oci
@@ -1042,6 +1082,34 @@ def test_stream_wrapper_emits_full_event_lifecycle(self):
10421082
self.assertEqual(events[1]["delta"]["message"]["content"]["type"], "text")
10431083
self.assertEqual(events[5]["delta"]["finish_reason"], "COMPLETE")
10441084

1085+
def test_stream_wrapper_emits_new_content_block_on_thinking_transition(self):
1086+
"""Test V2 streams emit a new content block when transitioning from thinking to text."""
1087+
import json
1088+
from cohere.oci_client import transform_oci_stream_wrapper
1089+
1090+
chunks = [
1091+
b'data: {"message": {"content": [{"type": "THINKING", "thinking": "Reasoning..."}]}}\n',
1092+
b'data: {"message": {"content": [{"type": "TEXT", "text": "Answer"}]}, "finishReason": "COMPLETE"}\n',
1093+
b"data: [DONE]\n",
1094+
]
1095+
1096+
events = []
1097+
for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True):
1098+
line = raw.decode("utf-8").strip()
1099+
if line.startswith("data: "):
1100+
events.append(json.loads(line[6:]))
1101+
1102+
self.assertEqual(events[1]["type"], "content-start")
1103+
self.assertEqual(events[1]["delta"]["message"]["content"]["type"], "thinking")
1104+
self.assertEqual(events[2]["type"], "content-delta")
1105+
self.assertEqual(events[2]["index"], 0)
1106+
self.assertEqual(events[3], {"type": "content-end", "index": 0})
1107+
self.assertEqual(events[4]["type"], "content-start")
1108+
self.assertEqual(events[4]["index"], 1)
1109+
self.assertEqual(events[4]["delta"]["message"]["content"]["type"], "text")
1110+
self.assertEqual(events[5]["type"], "content-delta")
1111+
self.assertEqual(events[5]["index"], 1)
1112+
10451113
def test_stream_wrapper_skips_malformed_json_with_warning(self):
10461114
"""Test that malformed JSON in SSE stream is skipped (not silently swallowed)."""
10471115
from cohere.oci_client import transform_oci_stream_wrapper
@@ -1080,6 +1148,7 @@ def test_v1_stream_wrapper_preserves_finish_reason_in_stream_end(self):
10801148
]
10811149

10821150
self.assertEqual(events[2]["event_type"], "stream-end")
1151+
self.assertEqual(events[2]["finish_reason"], "MAX_TOKENS")
10831152
self.assertEqual(events[2]["response"]["text"], "Hello world")
10841153
self.assertEqual(events[2]["response"]["finish_reason"], "MAX_TOKENS")
10851154

0 commit comments

Comments
 (0)