@@ -771,10 +771,34 @@ def transform_request_to_oci(
771771 chat_request ["temperature" ] = cohere_body ["temperature" ]
772772 if "max_tokens" in cohere_body :
773773 chat_request ["maxTokens" ] = cohere_body ["max_tokens" ]
774+ if "k" in cohere_body :
775+ chat_request ["topK" ] = cohere_body ["k" ]
776+ if "p" in cohere_body :
777+ chat_request ["topP" ] = cohere_body ["p" ]
778+ if "seed" in cohere_body :
779+ chat_request ["seed" ] = cohere_body ["seed" ]
780+ if "stop_sequences" in cohere_body :
781+ chat_request ["stopSequences" ] = cohere_body ["stop_sequences" ]
782+ if "frequency_penalty" in cohere_body :
783+ chat_request ["frequencyPenalty" ] = cohere_body ["frequency_penalty" ]
784+ if "presence_penalty" in cohere_body :
785+ chat_request ["presencePenalty" ] = cohere_body ["presence_penalty" ]
774786 if "preamble" in cohere_body :
775787 chat_request ["preambleOverride" ] = cohere_body ["preamble" ]
776788 if "chat_history" in cohere_body :
777789 chat_request ["chatHistory" ] = cohere_body ["chat_history" ]
790+ if "documents" in cohere_body :
791+ chat_request ["documents" ] = cohere_body ["documents" ]
792+ if "tools" in cohere_body :
793+ chat_request ["tools" ] = cohere_body ["tools" ]
794+ if "tool_results" in cohere_body :
795+ chat_request ["toolResults" ] = cohere_body ["tool_results" ]
796+ if "response_format" in cohere_body :
797+ chat_request ["responseFormat" ] = cohere_body ["response_format" ]
798+ if "safety_mode" in cohere_body :
799+ chat_request ["safetyMode" ] = cohere_body ["safety_mode" ]
800+ if "priority" in cohere_body :
801+ chat_request ["priority" ] = cohere_body ["priority" ]
778802
779803 # Handle streaming for both versions
780804 if "stream" in endpoint or cohere_body .get ("stream" ):
@@ -988,6 +1012,8 @@ def transform_oci_stream_wrapper(
9881012 generation_id = str (uuid .uuid4 ())
9891013 emitted_start = False
9901014 emitted_content_end = False
1015+ current_content_type : typing .Optional [str ] = None
1016+ current_content_index = 0
9911017 final_finish_reason = "COMPLETE"
9921018 final_usage : typing .Optional [typing .Dict [str , typing .Any ]] = None
9931019 full_v1_text = ""
@@ -1000,18 +1026,23 @@ def _emit_v2_event(event: typing.Dict[str, typing.Any]) -> bytes:
10001026 def _emit_v1_event (event : typing .Dict [str , typing .Any ]) -> bytes :
10011027 return json .dumps (event ).encode ("utf-8" ) + b"\n "
10021028
1029+ def _current_v2_content_type (oci_event : typing .Dict [str , typing .Any ]) -> str :
1030+ message = oci_event .get ("message" )
1031+ if isinstance (message , dict ):
1032+ content_list = message .get ("content" )
1033+ if content_list and isinstance (content_list , list ) and len (content_list ) > 0 :
1034+ oci_type = content_list [0 ].get ("type" , "TEXT" ).upper ()
1035+ if oci_type == "THINKING" :
1036+ return "thinking"
1037+ return "text"
1038+
10031039 def _transform_v2_event (oci_event : typing .Dict [str , typing .Any ]) -> typing .Iterator [bytes ]:
1004- nonlocal emitted_start , emitted_content_end , final_finish_reason , final_usage
1040+ nonlocal emitted_start , emitted_content_end , current_content_type , current_content_index
1041+ nonlocal final_finish_reason , final_usage
1042+
1043+ event_content_type = _current_v2_content_type (oci_event )
10051044
10061045 if not emitted_start :
1007- content_type = "text"
1008- message = oci_event .get ("message" )
1009- if isinstance (message , dict ):
1010- content_list = message .get ("content" )
1011- if content_list and isinstance (content_list , list ) and len (content_list ) > 0 :
1012- oci_type = content_list [0 ].get ("type" , "TEXT" ).upper ()
1013- if oci_type == "THINKING" :
1014- content_type = "thinking"
10151046
10161047 yield _emit_v2_event (
10171048 {
@@ -1023,15 +1054,30 @@ def _transform_v2_event(oci_event: typing.Dict[str, typing.Any]) -> typing.Itera
10231054 yield _emit_v2_event (
10241055 {
10251056 "type" : "content-start" ,
1026- "index" : 0 ,
1027- "delta" : {"message" : {"content" : {"type" : content_type }}},
1057+ "index" : current_content_index ,
1058+ "delta" : {"message" : {"content" : {"type" : event_content_type }}},
10281059 }
10291060 )
10301061 emitted_start = True
1062+ current_content_type = event_content_type
1063+ elif current_content_type != event_content_type :
1064+ yield _emit_v2_event ({"type" : "content-end" , "index" : current_content_index })
1065+ current_content_index += 1
1066+ yield _emit_v2_event (
1067+ {
1068+ "type" : "content-start" ,
1069+ "index" : current_content_index ,
1070+ "delta" : {"message" : {"content" : {"type" : event_content_type }}},
1071+ }
1072+ )
1073+ current_content_type = event_content_type
1074+ emitted_content_end = False
10311075
10321076 for cohere_event in typing .cast (
10331077 typing .List [typing .Dict [str , typing .Any ]], transform_stream_event (endpoint , oci_event , is_v2 = True )
10341078 ):
1079+ if "index" in cohere_event :
1080+ cohere_event = {** cohere_event , "index" : current_content_index }
10351081 if cohere_event ["type" ] == "content-end" :
10361082 emitted_content_end = True
10371083 final_finish_reason = oci_event .get ("finishReason" , final_finish_reason )
@@ -1069,6 +1115,7 @@ def _process_line(line: str) -> typing.Iterator[bytes]:
10691115 yield _emit_v1_event (
10701116 {
10711117 "event_type" : "stream-end" ,
1118+ "finish_reason" : final_v1_finish_reason ,
10721119 "response" : {
10731120 "text" : full_v1_text ,
10741121 "generation_id" : generation_id ,
0 commit comments