-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllm.py
More file actions
136 lines (121 loc) · 5.09 KB
/
llm.py
File metadata and controls
136 lines (121 loc) · 5.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
LLM abstraction layer with Anthropic Claude implementation.
Compatible with both old (0.x) and new (1.x+) anthropic package versions.
"""
import json
import os
import asyncio
from typing import Any, Dict, Optional
# Try to import anthropic, provide helpful error if not installed
try:
import anthropic
except ImportError:
anthropic = None
class LLM:
    """
    LLM interface with Anthropic Claude implementation.

    Set the ANTHROPIC_API_KEY environment variable to use.
    """

    def __init__(self, model: str = "claude-sonnet-4-20250514", temperature: float = 0.7):
        """
        Args:
            model: Default Claude model identifier used by generate().
            temperature: Sampling temperature passed to the API.

        Raises:
            ValueError: If ANTHROPIC_API_KEY is not set.
            ImportError: If the anthropic package is not installed.
            RuntimeError: If the installed anthropic version cannot be initialized.
        """
        self.model = model
        self.temperature = temperature
        # Use Haiku for simple prompt generation (10x cheaper, 2-3x faster)
        self.haiku_model = "claude-3-5-haiku-20241022"
        self.api_key = os.environ.get("ANTHROPIC_API_KEY")
        if not self.api_key:
            raise ValueError(
                "ANTHROPIC_API_KEY environment variable not set. "
                "Get your API key from https://console.anthropic.com/"
            )
        if anthropic is None:
            raise ImportError(
                "anthropic package not installed. Install with: pip install anthropic"
            )
        # Try to initialize client - handle both old and new API versions
        try:
            self.client = anthropic.Anthropic(api_key=self.api_key)
            self.api_version = "new"
        except TypeError as e:
            # Old (0.x) versions fail here complaining about the "proxies" kwarg
            if "proxies" in str(e):
                try:
                    # For older versions, just set the key and use module-level client
                    anthropic.api_key = self.api_key
                    self.client = anthropic
                    self.api_version = "old"
                except Exception as e2:
                    # Chain the original cause so the real failure stays debuggable
                    raise RuntimeError(
                        f"Failed to initialize Anthropic client. "
                        f"Your anthropic version ({anthropic.__version__}) may be incompatible. "
                        f"Try: pip install --upgrade anthropic"
                    ) from e2
            else:
                raise

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Strip surrounding markdown code fences (``` / ```json) from *text*."""
        text = text.strip()
        if text.startswith("```json"):
            text = text[7:]
        if text.startswith("```"):
            text = text[3:]
        if text.endswith("```"):
            text = text[:-3]
        return text.strip()

    def generate(self, system: str, user: str, response_format: str = "json",
                 model: Optional[str] = None, max_tokens: int = 4096) -> str:
        """
        Generate a completion given system and user prompts.

        Args:
            system: System prompt (role/instructions)
            user: User prompt (the actual request)
            response_format: "json" or "text"
            model: Optional model override (defaults to self.model)
            max_tokens: Maximum tokens to generate (default 4096)

        Returns:
            Raw string response (JSON string if response_format="json")

        Raises:
            RuntimeError: If the anthropic version is too old or the API call fails.
        """
        if self.api_version != "new":
            # Old API (0.x): the messages endpoint is unavailable; recommend upgrading.
            # Raised outside the try below so it is not re-wrapped as an API failure.
            raise RuntimeError(
                f"Your anthropic version ({anthropic.__version__}) is too old. "
                f"Please upgrade: pip install --upgrade anthropic\n"
                f"Required version: 0.39.0 or higher"
            )
        try:
            # New API (1.x+)
            response = self.client.messages.create(
                model=model or self.model,
                system=system,
                messages=[{"role": "user", "content": user}],
                temperature=self.temperature,
                max_tokens=max_tokens
            )
            # Extract text from response
            text = response.content[0].text
        except Exception as e:
            raise RuntimeError(f"LLM API call failed: {str(e)}") from e
        # If expecting JSON, clean up common markdown-fence formatting issues
        if response_format == "json":
            text = self._strip_code_fences(text)
        return text

    def generate_json(self, system: str, user: str) -> Dict[str, Any]:
        """
        Generate and parse JSON response.

        Raises:
            RuntimeError: If the response cannot be parsed as JSON.
        """
        response = self.generate(system, user, response_format="json")
        try:
            return json.loads(response)
        except json.JSONDecodeError as e:
            raise RuntimeError(
                f"Failed to parse JSON response from LLM. Error: {str(e)}\n"
                f"Raw response: {response[:500]}..."
            ) from e

    async def generate_async(self, system: str, user: str, response_format: str = "json",
                             model: Optional[str] = None, max_tokens: int = 4096) -> str:
        """
        Async version of generate() for parallel API calls.

        Runs the synchronous call in a thread pool to avoid blocking.
        """
        # get_running_loop() is the non-deprecated form inside a coroutine
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.generate(system, user, response_format, model, max_tokens)
        )