1515import unittest
1616
1717import cohere
18+ from cohere .errors import NotFoundError
1819
1920
2021@unittest .skipIf (os .getenv ("TEST_OCI" ) is None , "TEST_OCI not set" )
@@ -467,11 +468,14 @@ def test_embed_english_v3(self):
467468
468469 def test_embed_light_v3 (self ):
469470 """Test embed-english-light-v3.0 model."""
470- response = self .client .embed (
471- model = "embed-english-light-v3.0" ,
472- texts = ["Test" ],
473- input_type = "search_document" ,
474- )
471+ try :
472+ response = self .client .embed (
473+ model = "embed-english-light-v3.0" ,
474+ texts = ["Test" ],
475+ input_type = "search_document" ,
476+ )
477+ except NotFoundError :
478+ self .skipTest ("embed-english-light-v3.0 is not available in this OCI region/profile" )
475479 self .assertIsNotNone (response .embeddings )
476480 self .assertEqual (len (response .embeddings [0 ]), 384 )
477481
@@ -709,7 +713,7 @@ def test_get_oci_url_known_endpoints(self):
709713 url = get_oci_url ("us-chicago-1" , "chat" )
710714 self .assertIn ("/actions/chat" , url )
711715
712- url = get_oci_url ("us-chicago-1" , "chat_stream" , stream = True )
716+ url = get_oci_url ("us-chicago-1" , "chat_stream" )
713717 self .assertIn ("/actions/chat" , url )
714718
715719 def test_get_oci_url_unknown_endpoint_raises (self ):
@@ -744,7 +748,7 @@ def test_stream_wrapper_emits_full_event_lifecycle(self):
744748
745749 chunks = [
746750 b'data: {"message": {"content": [{"type": "TEXT", "text": "Hello"}]}}\n ' ,
747- b'data: {"message": {"content": [{"type": "TEXT", "text": " world"}]}, "finishReason": "COMPLETE"}\n ' ,
751+ b'data: {"message": {"content": [{"type": "TEXT", "text": " world"}]}, "finishReason": "COMPLETE", "usage": {"inputTokens": 3, "completionTokens": 2} }\n ' ,
748752 b"data: [DONE]\n " ,
749753 ]
750754
@@ -758,13 +762,19 @@ def test_stream_wrapper_emits_full_event_lifecycle(self):
758762 self .assertEqual (event_types [0 ], "message-start" )
759763 self .assertEqual (event_types [1 ], "content-start" )
760764 self .assertEqual (event_types [2 ], "content-delta" )
761- self .assertEqual (event_types [3 ], "content-end" )
762- self .assertEqual (event_types [4 ], "message-end" )
765+ self .assertEqual (event_types [3 ], "content-delta" )
766+ self .assertEqual (event_types [4 ], "content-end" )
767+ self .assertEqual (event_types [5 ], "message-end" )
763768
764769 self .assertIn ("id" , events [0 ])
765770 self .assertEqual (events [0 ]["delta" ]["message" ]["role" ], "assistant" )
766771 self .assertEqual (events [1 ]["index" ], 0 )
767772 self .assertEqual (events [1 ]["delta" ]["message" ]["content" ]["type" ], "text" )
773+ self .assertEqual (events [2 ]["delta" ]["message" ]["content" ]["text" ], "Hello" )
774+ self .assertEqual (events [3 ]["delta" ]["message" ]["content" ]["text" ], " world" )
775+ self .assertEqual (events [5 ]["delta" ]["finish_reason" ], "COMPLETE" )
776+ self .assertEqual (events [5 ]["delta" ]["usage" ]["tokens" ]["input_tokens" ], 3 )
777+ self .assertEqual (events [5 ]["delta" ]["usage" ]["tokens" ]["output_tokens" ], 2 )
768778
769779 def test_stream_wrapper_skips_malformed_json_with_warning (self ):
770780 """Test that malformed JSON in SSE streams is skipped with a warning."""
@@ -779,6 +789,60 @@ def test_stream_wrapper_skips_malformed_json_with_warning(self):
779789 events = list (transform_oci_stream_wrapper (iter (chunks ), "chat" , is_v2 = True ))
780790 self .assertEqual (len (events ), 4 )
781791
792+ def test_v1_stream_wrapper_emits_stream_end (self ):
793+ """Test that V1 chat streams end with a stream-end event containing the response payload."""
794+ import json
795+ from cohere .oci_client import transform_oci_stream_wrapper
796+
797+ chunks = [
798+ b'data: {"text": "Hello", "isFinished": false}\n ' ,
799+ b'data: {"text": " world", "isFinished": true, "finishReason": "COMPLETE"}\n ' ,
800+ b"data: [DONE]\n " ,
801+ ]
802+
803+ events = [
804+ json .loads (raw .decode ("utf-8" ))
805+ for raw in transform_oci_stream_wrapper (iter (chunks ), "chat_stream" , is_v2 = False )
806+ ]
807+
808+ self .assertEqual (events [0 ]["event_type" ], "text-generation" )
809+ self .assertEqual (events [1 ]["event_type" ], "text-generation" )
810+ self .assertEqual (events [2 ]["event_type" ], "stream-end" )
811+ self .assertEqual (events [2 ]["finish_reason" ], "COMPLETE" )
812+ self .assertEqual (events [2 ]["response" ]["text" ], "Hello world" )
813+
814+ def test_session_auth_expands_key_file_path (self ):
815+ """Test that session-based auth expands key_file paths before loading them."""
816+ from unittest .mock import MagicMock , patch
817+ from cohere .oci_client import map_request_to_oci
818+
819+ mock_private_key = object ()
820+ mock_security_token_signer = MagicMock ()
821+ mock_oci = MagicMock ()
822+ mock_oci .signer .load_private_key_from_file .return_value = mock_private_key
823+ mock_oci .auth .signers .SecurityTokenSigner .return_value = mock_security_token_signer
824+
825+ oci_config = {
826+ "security_token_file" : "~/.oci/sessions/TEST/token" ,
827+ "key_file" : "~/.oci/sessions/TEST/oci_api_key.pem" ,
828+ }
829+
830+ with patch ("cohere.oci_client.lazy_oci" , return_value = mock_oci ), patch (
831+ "builtins.open" , create = True
832+ ) as mock_open :
833+ mock_open .return_value .__enter__ .return_value .read .return_value = "token"
834+ with patch ("os.path.expanduser" , side_effect = lambda p : p .replace ("~" , "/Users/test" )):
835+ with patch ("cohere.oci_client.transform_request_to_oci" , return_value = {"compartmentId" : "c" }):
836+ hook = map_request_to_oci (oci_config , "us-chicago-1" , "compartment-123" , is_v2_client = False )
837+ request = MagicMock ()
838+ request .url .path = "/v1/chat"
839+ request .read .return_value = b'{"message":"hi"}'
840+ request .method = "POST"
841+ request .extensions = {}
842+ hook (request )
843+
844+ mock_oci .signer .load_private_key_from_file .assert_called_with ("/Users/test/.oci/sessions/TEST/oci_api_key.pem" )
845+
782846 def test_stream_wrapper_raises_on_transform_error (self ):
783847 """Test that stream transformation errors are re-raised with OCI-specific context."""
784848 from cohere .oci_client import transform_oci_stream_wrapper
0 commit comments