diff --git a/src/anthropic/lib/bedrock/_beta_messages.py b/src/anthropic/lib/bedrock/_beta_messages.py index 332f6fbab..59bafcc3a 100644 --- a/src/anthropic/lib/bedrock/_beta_messages.py +++ b/src/anthropic/lib/bedrock/_beta_messages.py @@ -13,6 +13,7 @@ class Messages(SyncAPIResource): create = FirstPartyMessagesAPI.create + count_tokens = FirstPartyMessagesAPI.count_tokens @cached_property def with_raw_response(self) -> MessagesWithRawResponse: @@ -36,6 +37,7 @@ def with_streaming_response(self) -> MessagesWithStreamingResponse: class AsyncMessages(AsyncAPIResource): create = FirstPartyAsyncMessagesAPI.create + count_tokens = FirstPartyAsyncMessagesAPI.count_tokens @cached_property def with_raw_response(self) -> AsyncMessagesWithRawResponse: diff --git a/src/anthropic/lib/bedrock/_client.py b/src/anthropic/lib/bedrock/_client.py index cda0690df..bce1aa77e 100644 --- a/src/anthropic/lib/bedrock/_client.py +++ b/src/anthropic/lib/bedrock/_client.py @@ -1,6 +1,8 @@ from __future__ import annotations import os +import json +import base64 import logging import urllib.parse from typing import Any, Union, Mapping, TypeVar @@ -10,7 +12,7 @@ from ... import _exceptions from ._beta import Beta, AsyncBeta -from ..._types import NOT_GIVEN, Timeout, NotGiven +from ..._types import NOT_GIVEN, Timeout, NotGiven, ResponseT from ..._utils import is_dict, is_given from ..._compat import model_copy from ..._version import __version__ @@ -61,8 +63,20 @@ def _prepare_options(input_options: FinalRequestOptions) -> FinalRequestOptions: if options.url.startswith("/v1/messages/batches"): raise AnthropicError("The Batch API is not supported in Bedrock yet") - if options.url == "/v1/messages/count_tokens": - raise AnthropicError("Token counting is not supported in Bedrock yet") + if options.url in {"/v1/messages/count_tokens", "/v1/messages/count_tokens?beta=true"} and options.method == "post": + if not is_dict(options.json_data): + raise RuntimeError("Expected dictionary json_data for post /v1/messages/count_tokens endpoint") + + model = options.json_data.pop("model", None) + model = urllib.parse.quote(str(model), safe=":") + options.url = f"/model/{model}/count-tokens" + options.json_data = { + "input": { + "invokeModel": { + "body": base64.b64encode(json.dumps(options.json_data).encode("utf-8")).decode("ascii"), + } + } + } return options @@ -91,6 +105,20 @@ def _infer_region() -> str: class BaseBedrockClient(BaseClient[_HttpxClientT, _DefaultStreamT]): + @override + def _process_response_data( + self, + *, + data: object, + cast_to: type[ResponseT], + response: httpx.Response, + ) -> ResponseT: + # the Bedrock CountTokens API returns `inputTokens` instead of `input_tokens` + if response.request.url.path.endswith("/count-tokens") and is_dict(data) and "inputTokens" in data: + data = {"input_tokens": data["inputTokens"]} + + return super()._process_response_data(data=data, cast_to=cast_to, response=response) + @override def _make_status_error( self, diff --git a/tests/lib/test_bedrock.py b/tests/lib/test_bedrock.py index 6e45c27f7..301316dd3 100644 --- a/tests/lib/test_bedrock.py +++ b/tests/lib/test_bedrock.py @@ -1,4 +1,6 @@ import re +import json +import base64 import typing as t import tempfile from typing import TypedDict, cast @@ -166,6 +168,92 @@ def test_application_inference_profile(respx_mock: MockRouter) -> None: ) +@pytest.mark.respx() +def test_count_tokens(respx_mock: MockRouter) -> None: + respx_mock.post(re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/count-tokens")).mock( + return_value=httpx.Response(200, json={"inputTokens": 42}), + ) + + count = sync_client.messages.count_tokens( + messages=[ + { + "role": "user", + "content": "Say hello there!", + } + ], + model="anthropic.claude-3-5-sonnet-20241022-v2:0", + ) + + assert count.input_tokens == 42 + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 1 + assert ( + calls[0].request.url + == "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-5-sonnet-20241022-v2:0/count-tokens" + ) + + request_body = json.loads(calls[0].request.read()) + inner_body = json.loads(base64.b64decode(request_body["input"]["invokeModel"]["body"])) + assert inner_body == { + "messages": [{"role": "user", "content": "Say hello there!"}], + "anthropic_version": "bedrock-2023-05-31", + } + + +@pytest.mark.respx() +@pytest.mark.asyncio() +async def test_count_tokens_async(respx_mock: MockRouter) -> None: + respx_mock.post(re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/count-tokens")).mock( + return_value=httpx.Response(200, json={"inputTokens": 42}), + ) + + count = await async_client.messages.count_tokens( + messages=[ + { + "role": "user", + "content": "Say hello there!", + } + ], + model="anthropic.claude-3-5-sonnet-20241022-v2:0", + ) + + assert count.input_tokens == 42 + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 1 + assert ( + calls[0].request.url + == "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-5-sonnet-20241022-v2:0/count-tokens" + ) + + +@pytest.mark.respx() +def test_count_tokens_beta(respx_mock: MockRouter) -> None: + respx_mock.post(re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/count-tokens")).mock( + return_value=httpx.Response(200, json={"inputTokens": 42}), + ) + + count = sync_client.beta.messages.count_tokens( + messages=[ + { + "role": "user", + "content": "Say hello there!", + } + ], + model="anthropic.claude-3-5-sonnet-20241022-v2:0", + ) + + assert count.input_tokens == 42 + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 1 + assert ( + calls[0].request.url + == "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-5-sonnet-20241022-v2:0/count-tokens" + ) + + sync_api_key_client = AnthropicBedrock( aws_region="us-east-1", api_key="test-api-key",