|
11 | 11 |
|
12 | 12 | from typing import TYPE_CHECKING |
13 | 13 |
|
14 | | -from sentry_sdk.tracing_utils import set_span_errored |
15 | | - |
16 | 14 | if TYPE_CHECKING: |
17 | 15 | from typing import Any, Callable, Iterator |
18 | 16 | from sentry_sdk.tracing import Span |
19 | 17 |
|
20 | 18 | import sentry_sdk |
21 | 19 | from sentry_sdk.scope import should_send_default_pii |
22 | | -from sentry_sdk.integrations import DidNotEnable, Integration |
23 | | -from sentry_sdk.utils import capture_internal_exceptions, event_from_exception, reraise |
24 | | -from sentry_sdk.integrations.cohere_v2 import setup_v2 |
25 | | - |
26 | | -try: |
27 | | - from cohere.client import Client |
28 | | - from cohere.base_client import BaseCohere |
29 | | - from cohere import ( |
30 | | - ChatStreamEndEvent, |
31 | | - NonStreamedChatResponse, |
32 | | - ) |
33 | | - |
34 | | - if TYPE_CHECKING: |
35 | | - from cohere import StreamedChatResponse |
36 | | -except ImportError: |
37 | | - raise DidNotEnable("Cohere not installed") |
| 20 | +from sentry_sdk.utils import capture_internal_exceptions, reraise |
38 | 21 |
|
39 | | -try: |
40 | | - # cohere 5.9.3+ |
41 | | - from cohere import StreamEndStreamedChatResponse |
42 | | -except ImportError: |
43 | | - from cohere import StreamedChatResponse_StreamEnd as StreamEndStreamedChatResponse |
44 | | - |
45 | | -COLLECTED_CHAT_PARAMS = { |
46 | | - "model": SPANDATA.GEN_AI_REQUEST_MODEL, |
47 | | - "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE, |
48 | | - "max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS, |
49 | | - "k": SPANDATA.GEN_AI_REQUEST_TOP_K, |
50 | | - "p": SPANDATA.GEN_AI_REQUEST_TOP_P, |
51 | | - "seed": SPANDATA.GEN_AI_REQUEST_SEED, |
52 | | - "frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY, |
53 | | - "presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY, |
54 | | -} |
| 22 | +from sentry_sdk.integrations.cohere import ( |
| 23 | + CohereIntegration, |
| 24 | + COLLECTED_CHAT_PARAMS, |
| 25 | + _capture_exception, |
| 26 | +) |
55 | 27 |
|
56 | 28 | COLLECTED_PII_CHAT_PARAMS = { |
57 | 29 | "tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, |
|
68 | 40 | } |
69 | 41 |
|
70 | 42 |
|
71 | | -class CohereIntegration(Integration): |
72 | | - identifier = "cohere" |
73 | | - origin = f"auto.ai.{identifier}" |
74 | | - |
75 | | - def __init__(self: "CohereIntegration", include_prompts: bool = True) -> None: |
76 | | - self.include_prompts = include_prompts |
| 43 | +def setup_v1(wrap_embed_fn): |
| 44 | + # type: (Callable[..., Any]) -> None |
| 45 | + """Called from CohereIntegration.setup_once() to patch V1 Client methods.""" |
| 46 | + try: |
| 47 | + from cohere.client import Client |
| 48 | + from cohere.base_client import BaseCohere |
| 49 | + except ImportError: |
| 50 | + return |
77 | 51 |
|
78 | | - @staticmethod |
79 | | - def setup_once() -> None: |
80 | | - BaseCohere.chat = _wrap_chat(BaseCohere.chat, streaming=False) |
81 | | - Client.embed = _wrap_embed(Client.embed) |
82 | | - BaseCohere.chat_stream = _wrap_chat(BaseCohere.chat_stream, streaming=True) |
83 | | - setup_v2(_wrap_embed) |
| 52 | + BaseCohere.chat = _wrap_chat(BaseCohere.chat, streaming=False) |
| 53 | + BaseCohere.chat_stream = _wrap_chat(BaseCohere.chat_stream, streaming=True) |
| 54 | + Client.embed = wrap_embed_fn(Client.embed) |
84 | 55 |
|
85 | 56 |
|
86 | | -def _capture_exception(exc: "Any") -> None: |
87 | | - set_span_errored() |
| 57 | +def _wrap_chat(f, streaming): |
| 58 | + # type: (Callable[..., Any], bool) -> Callable[..., Any] |
88 | 59 |
|
89 | | - event, hint = event_from_exception( |
90 | | - exc, |
91 | | - client_options=sentry_sdk.get_client().options, |
92 | | - mechanism={"type": "cohere", "handled": False}, |
93 | | - ) |
94 | | - sentry_sdk.capture_event(event, hint=hint) |
| 60 | + try: |
| 61 | + from cohere import ( |
| 62 | + ChatStreamEndEvent, |
| 63 | + NonStreamedChatResponse, |
| 64 | + ) |
95 | 65 |
|
| 66 | + if TYPE_CHECKING: |
| 67 | + from cohere import StreamedChatResponse |
| 68 | + except ImportError: |
| 69 | + return f |
| 70 | + |
| 71 | + try: |
| 72 | + # cohere 5.9.3+ |
| 73 | + from cohere import StreamEndStreamedChatResponse |
| 74 | + except ImportError: |
| 75 | + from cohere import ( |
| 76 | + StreamedChatResponse_StreamEnd as StreamEndStreamedChatResponse, |
| 77 | + ) |
96 | 78 |
|
97 | | -def _wrap_chat(f: "Callable[..., Any]", streaming: bool) -> "Callable[..., Any]": |
98 | | - def collect_chat_response_fields( |
99 | | - span: "Span", res: "NonStreamedChatResponse", include_pii: bool |
100 | | - ) -> None: |
| 79 | + def collect_chat_response_fields(span, res, include_pii): |
| 80 | + # type: (Span, NonStreamedChatResponse, bool) -> None |
101 | 81 | if include_pii: |
102 | 82 | if hasattr(res, "text"): |
103 | 83 | set_data_normalized( |
@@ -128,7 +108,8 @@ def collect_chat_response_fields( |
128 | 108 | ) |
129 | 109 |
|
130 | 110 | @wraps(f) |
131 | | - def new_chat(*args: "Any", **kwargs: "Any") -> "Any": |
| 111 | + def new_chat(*args, **kwargs): |
| 112 | + # type: (*Any, **Any) -> Any |
132 | 113 | integration = sentry_sdk.get_client().get_integration(CohereIntegration) |
133 | 114 |
|
134 | 115 | if ( |
@@ -194,7 +175,8 @@ def new_chat(*args: "Any", **kwargs: "Any") -> "Any": |
194 | 175 | if streaming: |
195 | 176 | old_iterator = res |
196 | 177 |
|
197 | | - def new_iterator() -> "Iterator[StreamedChatResponse]": |
| 178 | + def new_iterator(): |
| 179 | + # type: () -> Iterator[StreamedChatResponse] |
198 | 180 | with capture_internal_exceptions(): |
199 | 181 | for x in old_iterator: |
200 | 182 | if isinstance(x, ChatStreamEndEvent) or isinstance( |
@@ -225,62 +207,3 @@ def new_iterator() -> "Iterator[StreamedChatResponse]": |
225 | 207 | return res |
226 | 208 |
|
227 | 209 | return new_chat |
228 | | - |
229 | | - |
230 | | -def _wrap_embed(f: "Callable[..., Any]") -> "Callable[..., Any]": |
231 | | - @wraps(f) |
232 | | - def new_embed(*args: "Any", **kwargs: "Any") -> "Any": |
233 | | - integration = sentry_sdk.get_client().get_integration(CohereIntegration) |
234 | | - if integration is None: |
235 | | - return f(*args, **kwargs) |
236 | | - |
237 | | - model = kwargs.get("model", "") |
238 | | - |
239 | | - with sentry_sdk.start_span( |
240 | | - op=OP.GEN_AI_EMBEDDINGS, |
241 | | - name=f"embeddings {model}".strip(), |
242 | | - origin=CohereIntegration.origin, |
243 | | - ) as span: |
244 | | - set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "cohere") |
245 | | - set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings") |
246 | | - |
247 | | - if "texts" in kwargs and ( |
248 | | - should_send_default_pii() and integration.include_prompts |
249 | | - ): |
250 | | - if isinstance(kwargs["texts"], str): |
251 | | - set_data_normalized( |
252 | | - span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, [kwargs["texts"]] |
253 | | - ) |
254 | | - elif ( |
255 | | - isinstance(kwargs["texts"], list) |
256 | | - and len(kwargs["texts"]) > 0 |
257 | | - and isinstance(kwargs["texts"][0], str) |
258 | | - ): |
259 | | - set_data_normalized( |
260 | | - span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, kwargs["texts"] |
261 | | - ) |
262 | | - |
263 | | - if "model" in kwargs: |
264 | | - set_data_normalized( |
265 | | - span, SPANDATA.GEN_AI_REQUEST_MODEL, kwargs["model"] |
266 | | - ) |
267 | | - try: |
268 | | - res = f(*args, **kwargs) |
269 | | - except Exception as e: |
270 | | - exc_info = sys.exc_info() |
271 | | - with capture_internal_exceptions(): |
272 | | - _capture_exception(e) |
273 | | - reraise(*exc_info) |
274 | | - if ( |
275 | | - hasattr(res, "meta") |
276 | | - and hasattr(res.meta, "billed_units") |
277 | | - and hasattr(res.meta.billed_units, "input_tokens") |
278 | | - ): |
279 | | - record_token_usage( |
280 | | - span, |
281 | | - input_tokens=res.meta.billed_units.input_tokens, |
282 | | - total_tokens=res.meta.billed_units.input_tokens, |
283 | | - ) |
284 | | - return res |
285 | | - |
286 | | - return new_embed |
0 commit comments