@@ -516,6 +516,60 @@ def test_v2_response_tool_calls_conversion(self):
516516 self .assertEqual (result ["message" ]["tool_calls" ][0 ]["id" ], "call_123" )
517517
518518
519+ def test_normalize_model_for_oci (self ):
520+ """Test model name normalization for OCI."""
521+ from cohere .oci_client import normalize_model_for_oci
522+
523+ # Plain model name gets cohere. prefix
524+ self .assertEqual (normalize_model_for_oci ("command-a-03-2025" ), "cohere.command-a-03-2025" )
525+ # Already prefixed passes through
526+ self .assertEqual (normalize_model_for_oci ("cohere.embed-english-v3.0" ), "cohere.embed-english-v3.0" )
527+ # OCID passes through
528+ self .assertEqual (
529+ normalize_model_for_oci ("ocid1.generativeaimodel.oc1.us-chicago-1.abc" ),
530+ "ocid1.generativeaimodel.oc1.us-chicago-1.abc" ,
531+ )
532+
533+ def test_transform_embed_request (self ):
534+ """Test embed request transformation to OCI format."""
535+ from cohere .oci_client import transform_request_to_oci
536+
537+ body = {
538+ "model" : "embed-english-v3.0" ,
539+ "texts" : ["hello" , "world" ],
540+ "input_type" : "search_document" ,
541+ "truncate" : "end" ,
542+ "embedding_types" : ["float" , "int8" ],
543+ }
544+ result = transform_request_to_oci ("embed" , body , "compartment-123" )
545+
546+ self .assertEqual (result ["inputs" ], ["hello" , "world" ])
547+ self .assertEqual (result ["inputType" ], "SEARCH_DOCUMENT" )
548+ self .assertEqual (result ["truncate" ], "END" )
549+ self .assertEqual (result ["embeddingTypes" ], ["FLOAT" , "INT8" ])
550+ self .assertEqual (result ["compartmentId" ], "compartment-123" )
551+ self .assertEqual (result ["servingMode" ]["modelId" ], "cohere.embed-english-v3.0" )
552+
553+ def test_transform_chat_request_optional_params (self ):
554+ """Test chat request transformation includes optional params."""
555+ from cohere .oci_client import transform_request_to_oci
556+
557+ body = {
558+ "model" : "command-a-03-2025" ,
559+ "messages" : [{"role" : "user" , "content" : "Hi" }],
560+ "max_tokens" : 100 ,
561+ "temperature" : 0.7 ,
562+ "stop_sequences" : ["END" ],
563+ "frequency_penalty" : 0.5 ,
564+ }
565+ result = transform_request_to_oci ("chat" , body , "compartment-123" )
566+
567+ chat_req = result ["chatRequest" ]
568+ self .assertEqual (chat_req ["maxTokens" ], 100 )
569+ self .assertEqual (chat_req ["temperature" ], 0.7 )
570+ self .assertEqual (chat_req ["stopSequences" ], ["END" ])
571+ self .assertEqual (chat_req ["frequencyPenalty" ], 0.5 )
572+
519573 def test_get_oci_url_known_endpoints (self ):
520574 """Test URL generation for known endpoints."""
521575 from cohere .oci_client import get_oci_url
0 commit comments