Skip to content

Commit 7c1bf89

Browse files
committed
Use Anthropic for vision tool.
1 parent f3ba4fe commit 7c1bf89

4 files changed

Lines changed: 137 additions & 47 deletions

File tree

Lines changed: 81 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,106 @@
"""Vision tool for analyzing images using Anthropic's vision-capable Messages API."""
from typing import IO, Optional
import os
from pathlib import Path
import base64
import tempfile
import requests
import anthropic


__all__ = ["analyze_image"]


# Prompt sent to the model alongside the image; override via VISION_PROMPT.
PROMPT = os.getenv('VISION_PROMPT', "What's in this image?")
# Anthropic model used for analysis; override via VISION_MODEL.
MODEL = os.getenv('VISION_MODEL', "claude-3-5-sonnet-20241022")
# Maximum tokens in the model's reply. os.getenv returns a *string* whenever
# the env var is set, and the Anthropic client requires an integer — coerce.
MAX_TOKENS = int(os.getenv('VISION_MAX_TOKENS', 1024))

11-
def analyze_image(image_path_url: str) -> str:
12-
"""
13-
Analyze an image using OpenAI's Vision API.
15+
MEDIA_TYPES = {
16+
"jpg": "image/jpeg",
17+
"jpeg": "image/jpeg",
18+
"png": "image/png",
19+
"gif": "image/gif",
20+
"webp": "image/webp",
21+
}
1422

15-
Args:
16-
image_path_url: Local path or URL to the image
23+
# image sizes that will not be resized
24+
# TODO is there any value in resizing pre-upload?
25+
# 1:1 1092x1092 px
26+
# 3:4 951x1268 px
27+
# 2:3 896x1344 px
28+
# 9:16 819x1456 px
29+
# 1:2 784x1568 px
1730

18-
Returns:
19-
str: Description of the image contents
20-
"""
21-
client = OpenAI()
2231

23-
if not image_path_url:
24-
return "Image Path or URL is required."
32+
def _get_media_type(image_filename: str) -> Optional[str]:
33+
"""Get the media type from an image filename."""
34+
for ext, media_type in MEDIA_TYPES.items():
35+
if image_filename.endswith(ext):
36+
return media_type
37+
return None
38+
2539

26-
if "http" in image_path_url:
27-
return _analyze_web_image(client, image_path_url)
28-
return _analyze_local_image(client, image_path_url)
40+
def _encode_image(image_handle: IO) -> str:
41+
"""Encode a file handle to base64."""
42+
return base64.b64encode(image_handle.read()).decode("utf-8")
2943

def _make_anthropic_request(image_handle: IO, media_type: Optional[str]) -> "anthropic.types.Message":
    """Send the image in *image_handle* to the Anthropic Messages API.

    Args:
        image_handle: Binary file handle positioned at the start of the image.
        media_type: MIME type of the image (e.g. "image/png"); may be None
            when the caller could not determine it.

    Returns:
        The Anthropic Message response; callers read the text answer from
        ``response.content[0].text``. (The previous ``-> dict`` annotation
        was wrong — ``messages.create`` returns a Message object.)
    """
    client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment
    data = _encode_image(image_handle)
    return client.messages.create(
        model=MODEL,
        # Defensive coercion: MAX_TOKENS may arrive as a string if it was
        # read from an environment variable. int() is a no-op on ints.
        max_tokens=int(MAX_TOKENS),
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": data,
                        },
                    },
                    {
                        "type": "text",
                        "text": PROMPT,
                    },
                ],
            }
        ],
    )

def _analyze_web_image(image_url: str) -> str:
    """Download the image at *image_url* and analyze it via the Anthropic API.

    Args:
        image_url: Publicly reachable URL of the image.

    Returns:
        str: The model's text description of the image.

    Raises:
        requests.HTTPError: If the download fails (non-2xx status).
    """
    # Bound the download so an unresponsive host cannot hang the tool.
    download = requests.get(image_url, timeout=30)
    # Fail loudly instead of base64-encoding an HTML error page as an "image".
    download.raise_for_status()

    media_type = _get_media_type(image_url)
    if media_type is None:
        # URL has no recognized extension (e.g. signed URLs with query
        # strings) — fall back to the server-reported Content-Type.
        media_type = download.headers.get("Content-Type", "").split(";")[0] or None

    with tempfile.NamedTemporaryFile() as temp_file:
        temp_file.write(download.content)
        temp_file.flush()
        temp_file.seek(0)
        response = _make_anthropic_request(temp_file, media_type)
    return response.content[0].text
def _analyze_local_image(image_path: str) -> str:
    """Analyze an image stored on the local filesystem."""
    media_type = _get_media_type(image_path)
    with open(image_path, "rb") as handle:
        result = _make_anthropic_request(handle, media_type)
    return result.content[0].text
def analyze_image(image_path_or_url: str) -> str:
    """
    Analyze an image using Anthropic's vision-capable Messages API.

    (Docstring previously said "OpenAI's Vision API" — this module calls
    Anthropic throughout.)

    Args:
        image_path_or_url: Local path or URL to the image.

    Returns:
        str: Description of the image contents, or an error message when
        no path/URL was provided.
    """
    if not image_path_or_url:
        return "Image Path or URL is required."

    # Anchor the scheme check at the start of the string: a *local* path
    # that merely contains "http" (e.g. "/tmp/http_cache/img.png") must
    # not be treated as a URL.
    if image_path_or_url.startswith(("http://", "https://")):
        return _analyze_web_image(image_path_or_url)
    return _analyze_local_image(image_path_or_url)

agentstack/_tools/vision/config.json

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
"name": "vision",
33
"category": "image-analysis",
44
"env": {
5-
"OPENAI_API_KEY": null
5+
"ANTHROPIC_API_KEY": null,
6+
"VISION_PROMPT": null,
7+
"VISION_MODEL": null,
8+
"VISION_MAX_TOKENS": null
69
},
710
"dependencies": [
8-
"openai>=1.0.0",
11+
"anthropic>=0.45.2",
912
"requests>=2.31.0"
1013
],
1114
"tools": ["analyze_image"]

tests/fixtures/test_image.jpg

35.5 KB
Loading

tests/tools/test_tool_vision.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import os
2+
from pathlib import Path
3+
import unittest
4+
from agentstack._tools import ToolConfig
5+
6+
7+
TEST_IMAGE_PATH: Path = Path(__file__).parent.parent / 'fixtures/test_image.jpg'
8+
9+
10+
class VisionToolTest(unittest.TestCase):
    """Tests for the vision tool; live tests are skipped without an API key."""

    def setUp(self):
        # Install the tool's declared dependencies so the deferred imports
        # inside each test can resolve.
        tool = ToolConfig.from_tool_name('vision')
        for dependency in tool.dependencies:
            os.system(f"pip install {dependency}")

    def test_get_media_type(self):
        from agentstack._tools.vision import _get_media_type

        self.assertEqual(_get_media_type("image.jpg"), "image/jpeg")
        self.assertEqual(_get_media_type("image.jpeg"), "image/jpeg")
        self.assertEqual(_get_media_type("http://google.com/image.png"), "image/png")
        self.assertEqual(_get_media_type("/foo/bar/image.gif"), "image/gif")
        self.assertEqual(_get_media_type("image.webp"), "image/webp")
        self.assertIsNone(_get_media_type("document.pdf"))

    def test_encode_image(self):
        from agentstack._tools.vision import _encode_image

        with open(TEST_IMAGE_PATH, "rb") as image_file:
            encoded_image = _encode_image(image_file)
        # (Removed a leftover debug print of the encoded payload.)
        self.assertIsInstance(encoded_image, str)

    def test_analyze_image_web_live(self):
        from agentstack._tools.vision import analyze_image

        if not os.environ.get('ANTHROPIC_API_KEY'):
            self.skipTest("ANTHROPIC_API_KEY not set")

        image_url = "https://upload.wikimedia.org/wikipedia/en/f/f7/RickRoll.png"
        result = analyze_image(image_url)
        self.assertIsInstance(result, str)

    def test_analyze_image_local_live(self):
        from agentstack._tools.vision import analyze_image

        if not os.environ.get('ANTHROPIC_API_KEY'):
            self.skipTest("ANTHROPIC_API_KEY not set")

        result = analyze_image(str(TEST_IMAGE_PATH))
        self.assertIsInstance(result, str)

0 commit comments

Comments
 (0)