Skip to content

Commit ec237b4

Browse files
committed
move to separate folder
1 parent 54e99fe commit ec237b4

3 files changed

Lines changed: 170 additions & 120 deletions

File tree

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import sys
2+
from functools import wraps
3+
4+
from sentry_sdk.ai.monitoring import record_token_usage
5+
from sentry_sdk.consts import OP, SPANDATA
6+
from sentry_sdk.ai.utils import set_data_normalized
7+
8+
from typing import TYPE_CHECKING
9+
10+
from sentry_sdk.tracing_utils import set_span_errored
11+
12+
if TYPE_CHECKING:
13+
from typing import Any, Callable
14+
15+
import sentry_sdk
16+
from sentry_sdk.scope import should_send_default_pii
17+
from sentry_sdk.integrations import DidNotEnable, Integration
18+
from sentry_sdk.utils import capture_internal_exceptions, event_from_exception, reraise
19+
20+
try:
21+
from cohere import __version__ as cohere_version # noqa: F401
22+
except ImportError:
23+
raise DidNotEnable("Cohere not installed")
24+
25+
COLLECTED_CHAT_PARAMS = {
26+
"model": SPANDATA.GEN_AI_REQUEST_MODEL,
27+
"temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
28+
"max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
29+
"k": SPANDATA.GEN_AI_REQUEST_TOP_K,
30+
"p": SPANDATA.GEN_AI_REQUEST_TOP_P,
31+
"seed": SPANDATA.GEN_AI_REQUEST_SEED,
32+
"frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
33+
"presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
34+
}
35+
36+
37+
class CohereIntegration(Integration):
38+
identifier = "cohere"
39+
origin = f"auto.ai.{identifier}"
40+
41+
def __init__(self, include_prompts=True):
42+
# type: (bool) -> None
43+
self.include_prompts = include_prompts
44+
45+
@staticmethod
46+
def setup_once():
47+
# type: () -> None
48+
# Lazy imports to avoid circular dependencies:
49+
# v1/v2 import COLLECTED_CHAT_PARAMS and _capture_exception from this module.
50+
from sentry_sdk.integrations.cohere.v1 import setup_v1
51+
from sentry_sdk.integrations.cohere.v2 import setup_v2
52+
53+
setup_v1(_wrap_embed)
54+
setup_v2(_wrap_embed)
55+
56+
57+
def _capture_exception(exc):
58+
# type: (Any) -> None
59+
set_span_errored()
60+
61+
event, hint = event_from_exception(
62+
exc,
63+
client_options=sentry_sdk.get_client().options,
64+
mechanism={"type": "cohere", "handled": False},
65+
)
66+
sentry_sdk.capture_event(event, hint=hint)
67+
68+
69+
def _wrap_embed(f):
70+
# type: (Callable[..., Any]) -> Callable[..., Any]
71+
@wraps(f)
72+
def new_embed(*args, **kwargs):
73+
# type: (*Any, **Any) -> Any
74+
integration = sentry_sdk.get_client().get_integration(CohereIntegration)
75+
if integration is None:
76+
return f(*args, **kwargs)
77+
78+
model = kwargs.get("model", "")
79+
80+
with sentry_sdk.start_span(
81+
op=OP.GEN_AI_EMBEDDINGS,
82+
name=f"embeddings {model}".strip(),
83+
origin=CohereIntegration.origin,
84+
) as span:
85+
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "cohere")
86+
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
87+
88+
if "texts" in kwargs and (
89+
should_send_default_pii() and integration.include_prompts
90+
):
91+
if isinstance(kwargs["texts"], str):
92+
set_data_normalized(
93+
span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, [kwargs["texts"]]
94+
)
95+
elif (
96+
isinstance(kwargs["texts"], list)
97+
and len(kwargs["texts"]) > 0
98+
and isinstance(kwargs["texts"][0], str)
99+
):
100+
set_data_normalized(
101+
span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, kwargs["texts"]
102+
)
103+
104+
if "model" in kwargs:
105+
set_data_normalized(
106+
span, SPANDATA.GEN_AI_REQUEST_MODEL, kwargs["model"]
107+
)
108+
try:
109+
res = f(*args, **kwargs)
110+
except Exception as e:
111+
exc_info = sys.exc_info()
112+
with capture_internal_exceptions():
113+
_capture_exception(e)
114+
reraise(*exc_info)
115+
if (
116+
hasattr(res, "meta")
117+
and hasattr(res.meta, "billed_units")
118+
and hasattr(res.meta.billed_units, "input_tokens")
119+
):
120+
record_token_usage(
121+
span,
122+
input_tokens=res.meta.billed_units.input_tokens,
123+
total_tokens=res.meta.billed_units.input_tokens,
124+
)
125+
return res
126+
127+
return new_embed
Lines changed: 42 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -11,47 +11,19 @@
1111

1212
from typing import TYPE_CHECKING
1313

14-
from sentry_sdk.tracing_utils import set_span_errored
15-
1614
if TYPE_CHECKING:
1715
from typing import Any, Callable, Iterator
1816
from sentry_sdk.tracing import Span
1917

2018
import sentry_sdk
2119
from sentry_sdk.scope import should_send_default_pii
22-
from sentry_sdk.integrations import DidNotEnable, Integration
23-
from sentry_sdk.utils import capture_internal_exceptions, event_from_exception, reraise
24-
from sentry_sdk.integrations.cohere_v2 import setup_v2
25-
26-
try:
27-
from cohere.client import Client
28-
from cohere.base_client import BaseCohere
29-
from cohere import (
30-
ChatStreamEndEvent,
31-
NonStreamedChatResponse,
32-
)
33-
34-
if TYPE_CHECKING:
35-
from cohere import StreamedChatResponse
36-
except ImportError:
37-
raise DidNotEnable("Cohere not installed")
20+
from sentry_sdk.utils import capture_internal_exceptions, reraise
3821

39-
try:
40-
# cohere 5.9.3+
41-
from cohere import StreamEndStreamedChatResponse
42-
except ImportError:
43-
from cohere import StreamedChatResponse_StreamEnd as StreamEndStreamedChatResponse
44-
45-
COLLECTED_CHAT_PARAMS = {
46-
"model": SPANDATA.GEN_AI_REQUEST_MODEL,
47-
"temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
48-
"max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
49-
"k": SPANDATA.GEN_AI_REQUEST_TOP_K,
50-
"p": SPANDATA.GEN_AI_REQUEST_TOP_P,
51-
"seed": SPANDATA.GEN_AI_REQUEST_SEED,
52-
"frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
53-
"presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
54-
}
22+
from sentry_sdk.integrations.cohere import (
23+
CohereIntegration,
24+
COLLECTED_CHAT_PARAMS,
25+
_capture_exception,
26+
)
5527

5628
COLLECTED_PII_CHAT_PARAMS = {
5729
"tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
@@ -68,36 +40,44 @@
6840
}
6941

7042

71-
class CohereIntegration(Integration):
72-
identifier = "cohere"
73-
origin = f"auto.ai.{identifier}"
74-
75-
def __init__(self: "CohereIntegration", include_prompts: bool = True) -> None:
76-
self.include_prompts = include_prompts
43+
def setup_v1(wrap_embed_fn):
44+
# type: (Callable[..., Any]) -> None
45+
"""Called from CohereIntegration.setup_once() to patch V1 Client methods."""
46+
try:
47+
from cohere.client import Client
48+
from cohere.base_client import BaseCohere
49+
except ImportError:
50+
return
7751

78-
@staticmethod
79-
def setup_once() -> None:
80-
BaseCohere.chat = _wrap_chat(BaseCohere.chat, streaming=False)
81-
Client.embed = _wrap_embed(Client.embed)
82-
BaseCohere.chat_stream = _wrap_chat(BaseCohere.chat_stream, streaming=True)
83-
setup_v2(_wrap_embed)
52+
BaseCohere.chat = _wrap_chat(BaseCohere.chat, streaming=False)
53+
BaseCohere.chat_stream = _wrap_chat(BaseCohere.chat_stream, streaming=True)
54+
Client.embed = wrap_embed_fn(Client.embed)
8455

8556

86-
def _capture_exception(exc: "Any") -> None:
87-
set_span_errored()
57+
def _wrap_chat(f, streaming):
58+
# type: (Callable[..., Any], bool) -> Callable[..., Any]
8859

89-
event, hint = event_from_exception(
90-
exc,
91-
client_options=sentry_sdk.get_client().options,
92-
mechanism={"type": "cohere", "handled": False},
93-
)
94-
sentry_sdk.capture_event(event, hint=hint)
60+
try:
61+
from cohere import (
62+
ChatStreamEndEvent,
63+
NonStreamedChatResponse,
64+
)
9565

66+
if TYPE_CHECKING:
67+
from cohere import StreamedChatResponse
68+
except ImportError:
69+
return f
70+
71+
try:
72+
# cohere 5.9.3+
73+
from cohere import StreamEndStreamedChatResponse
74+
except ImportError:
75+
from cohere import (
76+
StreamedChatResponse_StreamEnd as StreamEndStreamedChatResponse,
77+
)
9678

97-
def _wrap_chat(f: "Callable[..., Any]", streaming: bool) -> "Callable[..., Any]":
98-
def collect_chat_response_fields(
99-
span: "Span", res: "NonStreamedChatResponse", include_pii: bool
100-
) -> None:
79+
def collect_chat_response_fields(span, res, include_pii):
80+
# type: (Span, NonStreamedChatResponse, bool) -> None
10181
if include_pii:
10282
if hasattr(res, "text"):
10383
set_data_normalized(
@@ -128,7 +108,8 @@ def collect_chat_response_fields(
128108
)
129109

130110
@wraps(f)
131-
def new_chat(*args: "Any", **kwargs: "Any") -> "Any":
111+
def new_chat(*args, **kwargs):
112+
# type: (*Any, **Any) -> Any
132113
integration = sentry_sdk.get_client().get_integration(CohereIntegration)
133114

134115
if (
@@ -194,7 +175,8 @@ def new_chat(*args: "Any", **kwargs: "Any") -> "Any":
194175
if streaming:
195176
old_iterator = res
196177

197-
def new_iterator() -> "Iterator[StreamedChatResponse]":
178+
def new_iterator():
179+
# type: () -> Iterator[StreamedChatResponse]
198180
with capture_internal_exceptions():
199181
for x in old_iterator:
200182
if isinstance(x, ChatStreamEndEvent) or isinstance(
@@ -225,62 +207,3 @@ def new_iterator() -> "Iterator[StreamedChatResponse]":
225207
return res
226208

227209
return new_chat
228-
229-
230-
def _wrap_embed(f: "Callable[..., Any]") -> "Callable[..., Any]":
231-
@wraps(f)
232-
def new_embed(*args: "Any", **kwargs: "Any") -> "Any":
233-
integration = sentry_sdk.get_client().get_integration(CohereIntegration)
234-
if integration is None:
235-
return f(*args, **kwargs)
236-
237-
model = kwargs.get("model", "")
238-
239-
with sentry_sdk.start_span(
240-
op=OP.GEN_AI_EMBEDDINGS,
241-
name=f"embeddings {model}".strip(),
242-
origin=CohereIntegration.origin,
243-
) as span:
244-
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "cohere")
245-
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
246-
247-
if "texts" in kwargs and (
248-
should_send_default_pii() and integration.include_prompts
249-
):
250-
if isinstance(kwargs["texts"], str):
251-
set_data_normalized(
252-
span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, [kwargs["texts"]]
253-
)
254-
elif (
255-
isinstance(kwargs["texts"], list)
256-
and len(kwargs["texts"]) > 0
257-
and isinstance(kwargs["texts"][0], str)
258-
):
259-
set_data_normalized(
260-
span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, kwargs["texts"]
261-
)
262-
263-
if "model" in kwargs:
264-
set_data_normalized(
265-
span, SPANDATA.GEN_AI_REQUEST_MODEL, kwargs["model"]
266-
)
267-
try:
268-
res = f(*args, **kwargs)
269-
except Exception as e:
270-
exc_info = sys.exc_info()
271-
with capture_internal_exceptions():
272-
_capture_exception(e)
273-
reraise(*exc_info)
274-
if (
275-
hasattr(res, "meta")
276-
and hasattr(res.meta, "billed_units")
277-
and hasattr(res.meta.billed_units, "input_tokens")
278-
):
279-
record_token_usage(
280-
span,
281-
input_tokens=res.meta.billed_units.input_tokens,
282-
total_tokens=res.meta.billed_units.input_tokens,
283-
)
284-
return res
285-
286-
return new_embed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def setup_v2(wrap_embed_fn):
5555
# type: (Callable[..., Any]) -> None
5656
"""Called from CohereIntegration.setup_once() to patch V2Client methods.
5757
58-
The embed wrapper is passed in from cohere.py to reuse the same _wrap_embed
58+
The embed wrapper is passed in from __init__.py to reuse the same _wrap_embed
5959
for both V1 and V2, since the embed response format (.meta.billed_units)
6060
is identical across both API versions.
6161
"""

0 commit comments

Comments
 (0)