Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/anthropic/lib/bedrock/_beta_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

class Messages(SyncAPIResource):
create = FirstPartyMessagesAPI.create
count_tokens = FirstPartyMessagesAPI.count_tokens

@cached_property
def with_raw_response(self) -> MessagesWithRawResponse:
Expand All @@ -36,6 +37,7 @@ def with_streaming_response(self) -> MessagesWithStreamingResponse:

class AsyncMessages(AsyncAPIResource):
create = FirstPartyAsyncMessagesAPI.create
count_tokens = FirstPartyAsyncMessagesAPI.count_tokens

@cached_property
def with_raw_response(self) -> AsyncMessagesWithRawResponse:
Expand Down
34 changes: 31 additions & 3 deletions src/anthropic/lib/bedrock/_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import os
import json
import base64
import logging
import urllib.parse
from typing import Any, Union, Mapping, TypeVar
Expand All @@ -10,7 +12,7 @@

from ... import _exceptions
from ._beta import Beta, AsyncBeta
from ..._types import NOT_GIVEN, Timeout, NotGiven
from ..._types import NOT_GIVEN, Timeout, NotGiven, ResponseT
from ..._utils import is_dict, is_given
from ..._compat import model_copy
from ..._version import __version__
Expand Down Expand Up @@ -61,8 +63,20 @@ def _prepare_options(input_options: FinalRequestOptions) -> FinalRequestOptions:
if options.url.startswith("/v1/messages/batches"):
raise AnthropicError("The Batch API is not supported in Bedrock yet")

if options.url == "/v1/messages/count_tokens":
raise AnthropicError("Token counting is not supported in Bedrock yet")
if options.url in {"/v1/messages/count_tokens", "/v1/messages/count_tokens?beta=true"} and options.method == "post":
if not is_dict(options.json_data):
raise RuntimeError("Expected dictionary json_data for post /v1/messages/count_tokens endpoint")

model = options.json_data.pop("model", None)
model = urllib.parse.quote(str(model), safe=":")
options.url = f"/model/{model}/count-tokens"
options.json_data = {
"input": {
"invokeModel": {
"body": base64.b64encode(json.dumps(options.json_data).encode("utf-8")).decode("ascii"),
}
}
}

return options

Expand Down Expand Up @@ -91,6 +105,20 @@ def _infer_region() -> str:


class BaseBedrockClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
@override
def _process_response_data(
    self,
    *,
    data: object,
    cast_to: type[ResponseT],
    response: httpx.Response,
) -> ResponseT:
    """Adapt Bedrock CountTokens payloads, then defer to the base implementation.

    The Bedrock CountTokens API returns `inputTokens` instead of
    `input_tokens`. When the request path ends in `/count-tokens` and the
    decoded payload is a dict carrying `inputTokens`, the payload is replaced
    by a dict holding only `input_tokens` before it is handed to the parent
    class. Every other response is passed through unchanged.
    """
    targets_count_tokens = response.request.url.path.endswith("/count-tokens")

    if targets_count_tokens and is_dict(data) and "inputTokens" in data:
        # rename the camelCase key to the snake_case one
        data = {"input_tokens": data["inputTokens"]}

    return super()._process_response_data(data=data, cast_to=cast_to, response=response)

@override
def _make_status_error(
self,
Expand Down
88 changes: 88 additions & 0 deletions tests/lib/test_bedrock.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import re
import json
import base64
import typing as t
import tempfile
from typing import TypedDict, cast
Expand Down Expand Up @@ -166,6 +168,92 @@ def test_application_inference_profile(respx_mock: MockRouter) -> None:
)


@pytest.mark.respx()
def test_count_tokens(respx_mock: MockRouter) -> None:
    """Sync `messages.count_tokens` hits the Bedrock count-tokens route with a wrapped body."""
    endpoint_pattern = re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/count-tokens")
    mocked_response = httpx.Response(200, json={"inputTokens": 42})
    respx_mock.post(endpoint_pattern).mock(return_value=mocked_response)

    result = sync_client.messages.count_tokens(
        model="anthropic.claude-3-5-sonnet-20241022-v2:0",
        messages=[
            {
                "role": "user",
                "content": "Say hello there!",
            }
        ],
    )

    # the mocked `inputTokens` must surface as `input_tokens`
    assert result.input_tokens == 42

    recorded = cast("list[MockRequestCall]", respx_mock.calls)
    assert len(recorded) == 1

    sent_request = recorded[0].request
    assert (
        sent_request.url
        == "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-5-sonnet-20241022-v2:0/count-tokens"
    )

    # the original payload travels base64-encoded under input.invokeModel.body
    envelope = json.loads(sent_request.read())
    encoded_body = envelope["input"]["invokeModel"]["body"]
    decoded_body = json.loads(base64.b64decode(encoded_body))
    assert decoded_body == {
        "messages": [{"role": "user", "content": "Say hello there!"}],
        "anthropic_version": "bedrock-2023-05-31",
    }


@pytest.mark.respx()
@pytest.mark.asyncio()
async def test_count_tokens_async(respx_mock: MockRouter) -> None:
    """Async `messages.count_tokens` hits the Bedrock count-tokens route with a wrapped body.

    Checks the parsed token count, the request URL, and the base64-encoded
    `input.invokeModel.body` payload, so the async path is covered to the
    same depth as the sync test.
    """
    respx_mock.post(re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/count-tokens")).mock(
        return_value=httpx.Response(200, json={"inputTokens": 42}),
    )

    count = await async_client.messages.count_tokens(
        messages=[
            {
                "role": "user",
                "content": "Say hello there!",
            }
        ],
        model="anthropic.claude-3-5-sonnet-20241022-v2:0",
    )

    assert count.input_tokens == 42

    calls = cast("list[MockRequestCall]", respx_mock.calls)
    assert len(calls) == 1
    assert (
        calls[0].request.url
        == "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-5-sonnet-20241022-v2:0/count-tokens"
    )

    # Verify the request envelope too: without this, a regression in how the
    # async client wraps the payload would go unnoticed.
    request_body = json.loads(calls[0].request.read())
    inner_body = json.loads(base64.b64decode(request_body["input"]["invokeModel"]["body"]))
    assert inner_body == {
        "messages": [{"role": "user", "content": "Say hello there!"}],
        "anthropic_version": "bedrock-2023-05-31",
    }


@pytest.mark.respx()
def test_count_tokens_beta(respx_mock: MockRouter) -> None:
    """Sync `beta.messages.count_tokens` is routed to the Bedrock count-tokens endpoint."""
    endpoint_pattern = re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/count-tokens")
    mocked_response = httpx.Response(200, json={"inputTokens": 42})
    respx_mock.post(endpoint_pattern).mock(return_value=mocked_response)

    result = sync_client.beta.messages.count_tokens(
        model="anthropic.claude-3-5-sonnet-20241022-v2:0",
        messages=[
            {
                "role": "user",
                "content": "Say hello there!",
            }
        ],
    )

    # the mocked `inputTokens` must surface as `input_tokens`
    assert result.input_tokens == 42

    recorded = cast("list[MockRequestCall]", respx_mock.calls)
    assert len(recorded) == 1

    sent_request = recorded[0].request
    assert (
        sent_request.url
        == "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-5-sonnet-20241022-v2:0/count-tokens"
    )


sync_api_key_client = AnthropicBedrock(
aws_region="us-east-1",
api_key="test-api-key",
Expand Down
Loading