66from sentry_sdk .ai .monitoring import record_token_usage
77from sentry_sdk .ai .utils import get_start_span_function , set_data_normalized
88from sentry_sdk .consts import SPANDATA
9+ from sentry_sdk .traces import StreamedSpan
10+ from sentry_sdk .tracing_utils import has_span_streaming_enabled
911
1012if TYPE_CHECKING :
11- from typing import Any , Callable , Iterator
13+ from typing import Any , Callable , Iterator , Union
1214
1315 from sentry_sdk .tracing import Span
1416
@@ -90,9 +92,18 @@ def _capture_exception(exc: "Any") -> None:
9092 sentry_sdk .capture_event (event , hint = hint )
9193
9294
95+ def _end_span (span : "Any" ) -> None :
96+ if isinstance (span , StreamedSpan ):
97+ span .end ()
98+ else :
99+ span .__exit__ (None , None , None )
100+
101+
93102def _wrap_chat (f : "Callable[..., Any]" , streaming : bool ) -> "Callable[..., Any]" :
94103 def collect_chat_response_fields (
95- span : "Span" , res : "NonStreamedChatResponse" , include_pii : bool
104+ span : "Union[Span, StreamedSpan]" ,
105+ res : "NonStreamedChatResponse" ,
106+ include_pii : bool ,
96107 ) -> None :
97108 if include_pii :
98109 if hasattr (res , "text" ):
@@ -129,6 +140,9 @@ def collect_chat_response_fields(
129140 @wraps (f )
130141 def new_chat (* args : "Any" , ** kwargs : "Any" ) -> "Any" :
131142 integration = sentry_sdk .get_client ().get_integration (CohereIntegration )
143+ is_span_streaming_enabled = has_span_streaming_enabled (
144+ sentry_sdk .get_client ().options
145+ )
132146
133147 if (
134148 integration is None
@@ -139,19 +153,29 @@ def new_chat(*args: "Any", **kwargs: "Any") -> "Any":
139153
140154 message = kwargs .get ("message" )
141155
142- span = get_start_span_function ()(
143- op = consts .OP .COHERE_CHAT_COMPLETIONS_CREATE ,
144- name = "cohere.client.Chat" ,
145- origin = CohereIntegration .origin ,
146- )
147- span .__enter__ ()
156+ if is_span_streaming_enabled :
157+ span = sentry_sdk .traces .start_span (
158+ name = "cohere.client.Chat" ,
159+ attributes = {
160+ "sentry.op" : consts .OP .COHERE_CHAT_COMPLETIONS_CREATE ,
161+ "sentry.origin" : CohereIntegration .origin ,
162+ },
163+ )
164+ else :
165+ span = get_start_span_function ()(
166+ op = consts .OP .COHERE_CHAT_COMPLETIONS_CREATE ,
167+ name = "cohere.client.Chat" ,
168+ origin = CohereIntegration .origin ,
169+ )
170+ span .__enter__ ()
148171 try :
149172 res = f (* args , ** kwargs )
150173 except Exception as e :
151174 exc_info = sys .exc_info ()
152175 with capture_internal_exceptions ():
153176 _capture_exception (e )
154177 span .__exit__ (* exc_info )
178+
155179 reraise (* exc_info )
156180
157181 with capture_internal_exceptions ():
@@ -195,8 +219,7 @@ def new_iterator() -> "Iterator[StreamedChatResponse]":
195219 and integration .include_prompts ,
196220 )
197221 yield x
198-
199- span .__exit__ (None , None , None )
222+ _end_span (span )
200223
201224 return new_iterator ()
202225 elif isinstance (res , NonStreamedChatResponse ):
@@ -206,11 +229,11 @@ def new_iterator() -> "Iterator[StreamedChatResponse]":
206229 include_pii = should_send_default_pii ()
207230 and integration .include_prompts ,
208231 )
209- span . __exit__ ( None , None , None )
232+ _end_span ( span )
210233 else :
211234 set_data_normalized (span , "unknown_response" , True )
212- span . __exit__ ( None , None , None )
213- return res
235+ _end_span ( span )
236+ return res
214237
215238 return new_chat
216239
@@ -222,11 +245,26 @@ def new_embed(*args: "Any", **kwargs: "Any") -> "Any":
222245 if integration is None :
223246 return f (* args , ** kwargs )
224247
225- with get_start_span_function ()(
226- op = consts .OP .COHERE_EMBEDDINGS_CREATE ,
227- name = "Cohere Embedding Creation" ,
228- origin = CohereIntegration .origin ,
229- ) as span :
248+ is_span_streaming_enabled = has_span_streaming_enabled (
249+ sentry_sdk .get_client ().options
250+ )
251+
252+ if is_span_streaming_enabled :
253+ span_ctx = sentry_sdk .traces .start_span (
254+ name = "Cohere Embedding Creation" ,
255+ attributes = {
256+ "sentry.op" : consts .OP .COHERE_EMBEDDINGS_CREATE ,
257+ "sentry.origin" : CohereIntegration .origin ,
258+ },
259+ )
260+ else :
261+ span_ctx = get_start_span_function ()(
262+ op = consts .OP .COHERE_EMBEDDINGS_CREATE ,
263+ name = "Cohere Embedding Creation" ,
264+ origin = CohereIntegration .origin ,
265+ )
266+
267+ with span_ctx as span :
230268 if "texts" in kwargs and (
231269 should_send_default_pii () and integration .include_prompts
232270 ):
0 commit comments