Skip to content

Commit 494103e

Browse files
authored
feat(cohere): Add span streaming support (#6479)
Enable the Cohere integration to use span streaming when configured. Add support for both StreamedSpan and traditional span lifecycle patterns. Fixes PY-2315 Fixes #6013
1 parent f40212b commit 494103e

2 files changed

Lines changed: 286 additions & 95 deletions

File tree

sentry_sdk/integrations/cohere.py

Lines changed: 56 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
from sentry_sdk.ai.monitoring import record_token_usage
77
from sentry_sdk.ai.utils import get_start_span_function, set_data_normalized
88
from sentry_sdk.consts import SPANDATA
9+
from sentry_sdk.traces import StreamedSpan
10+
from sentry_sdk.tracing_utils import has_span_streaming_enabled
911

1012
if TYPE_CHECKING:
11-
from typing import Any, Callable, Iterator
13+
from typing import Any, Callable, Iterator, Union
1214

1315
from sentry_sdk.tracing import Span
1416

@@ -90,9 +92,18 @@ def _capture_exception(exc: "Any") -> None:
9092
sentry_sdk.capture_event(event, hint=hint)
9193

9294

95+
def _end_span(span: "Any") -> None:
96+
if isinstance(span, StreamedSpan):
97+
span.end()
98+
else:
99+
span.__exit__(None, None, None)
100+
101+
93102
def _wrap_chat(f: "Callable[..., Any]", streaming: bool) -> "Callable[..., Any]":
94103
def collect_chat_response_fields(
95-
span: "Span", res: "NonStreamedChatResponse", include_pii: bool
104+
span: "Union[Span, StreamedSpan]",
105+
res: "NonStreamedChatResponse",
106+
include_pii: bool,
96107
) -> None:
97108
if include_pii:
98109
if hasattr(res, "text"):
@@ -129,6 +140,9 @@ def collect_chat_response_fields(
129140
@wraps(f)
130141
def new_chat(*args: "Any", **kwargs: "Any") -> "Any":
131142
integration = sentry_sdk.get_client().get_integration(CohereIntegration)
143+
is_span_streaming_enabled = has_span_streaming_enabled(
144+
sentry_sdk.get_client().options
145+
)
132146

133147
if (
134148
integration is None
@@ -139,19 +153,29 @@ def new_chat(*args: "Any", **kwargs: "Any") -> "Any":
139153

140154
message = kwargs.get("message")
141155

142-
span = get_start_span_function()(
143-
op=consts.OP.COHERE_CHAT_COMPLETIONS_CREATE,
144-
name="cohere.client.Chat",
145-
origin=CohereIntegration.origin,
146-
)
147-
span.__enter__()
156+
if is_span_streaming_enabled:
157+
span = sentry_sdk.traces.start_span(
158+
name="cohere.client.Chat",
159+
attributes={
160+
"sentry.op": consts.OP.COHERE_CHAT_COMPLETIONS_CREATE,
161+
"sentry.origin": CohereIntegration.origin,
162+
},
163+
)
164+
else:
165+
span = get_start_span_function()(
166+
op=consts.OP.COHERE_CHAT_COMPLETIONS_CREATE,
167+
name="cohere.client.Chat",
168+
origin=CohereIntegration.origin,
169+
)
170+
span.__enter__()
148171
try:
149172
res = f(*args, **kwargs)
150173
except Exception as e:
151174
exc_info = sys.exc_info()
152175
with capture_internal_exceptions():
153176
_capture_exception(e)
154177
span.__exit__(*exc_info)
178+
155179
reraise(*exc_info)
156180

157181
with capture_internal_exceptions():
@@ -195,8 +219,7 @@ def new_iterator() -> "Iterator[StreamedChatResponse]":
195219
and integration.include_prompts,
196220
)
197221
yield x
198-
199-
span.__exit__(None, None, None)
222+
_end_span(span)
200223

201224
return new_iterator()
202225
elif isinstance(res, NonStreamedChatResponse):
@@ -206,11 +229,11 @@ def new_iterator() -> "Iterator[StreamedChatResponse]":
206229
include_pii=should_send_default_pii()
207230
and integration.include_prompts,
208231
)
209-
span.__exit__(None, None, None)
232+
_end_span(span)
210233
else:
211234
set_data_normalized(span, "unknown_response", True)
212-
span.__exit__(None, None, None)
213-
return res
235+
_end_span(span)
236+
return res
214237

215238
return new_chat
216239

@@ -222,11 +245,26 @@ def new_embed(*args: "Any", **kwargs: "Any") -> "Any":
222245
if integration is None:
223246
return f(*args, **kwargs)
224247

225-
with get_start_span_function()(
226-
op=consts.OP.COHERE_EMBEDDINGS_CREATE,
227-
name="Cohere Embedding Creation",
228-
origin=CohereIntegration.origin,
229-
) as span:
248+
is_span_streaming_enabled = has_span_streaming_enabled(
249+
sentry_sdk.get_client().options
250+
)
251+
252+
if is_span_streaming_enabled:
253+
span_ctx = sentry_sdk.traces.start_span(
254+
name="Cohere Embedding Creation",
255+
attributes={
256+
"sentry.op": consts.OP.COHERE_EMBEDDINGS_CREATE,
257+
"sentry.origin": CohereIntegration.origin,
258+
},
259+
)
260+
else:
261+
span_ctx = get_start_span_function()(
262+
op=consts.OP.COHERE_EMBEDDINGS_CREATE,
263+
name="Cohere Embedding Creation",
264+
origin=CohereIntegration.origin,
265+
)
266+
267+
with span_ctx as span:
230268
if "texts" in kwargs and (
231269
should_send_default_pii() and integration.include_prompts
232270
):

0 commit comments

Comments
 (0)