-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
126 lines (98 loc) · 3.88 KB
/
model.py
File metadata and controls
126 lines (98 loc) · 3.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import math
from dataclasses import dataclass
from typing import Optional, List
from abc import ABC, abstractmethod
from dotenv import load_dotenv
load_dotenv()
@dataclass
class ModelResponse:
    """Normalized result of a single LLM completion, independent of provider."""

    # Raw completion text returned by the provider (already stripped by callers).
    text: str
    # Average token probability in [0, 1] derived from logprobs; None when the
    # provider does not expose logprobs (both wrappers below currently pass None).
    true_confidence: Optional[float] = None
class BaseModel(ABC):
    """Abstract interface that every provider wrapper implements."""

    @abstractmethod
    def generate(self, system_prompt: str, user_prompt: str, json_mode: bool = False) -> ModelResponse:
        """Run one completion.

        Args:
            system_prompt: Instructions placed in the system role.
            user_prompt: The user-turn content.
            json_mode: When True, ask the provider to emit a JSON object.

        Returns:
            A ModelResponse with the completion text.
        """
        pass

    def _calc_confidence_from_logprobs(self, logprobs: List[float]) -> Optional[float]:
        """Average token probability computed from per-token log-probabilities.

        Returns None for an empty (or falsy) input — the previous `-> float`
        annotation did not match that behavior, so it is now Optional[float].
        """
        if not logprobs:
            return None
        probs = [math.exp(lp) for lp in logprobs]
        avg_prob = sum(probs) / len(probs)
        return round(avg_prob, 4)
class MistralWrapper(BaseModel):
    """BaseModel implementation backed by the Mistral chat-completions API."""

    def __init__(
        self,
        model_name: str = "mistral-large-latest",
        max_tokens: int = 16384,
        api_key: Optional[str] = None,
        temperature: float = 0.7,
    ):
        """Create a Mistral client.

        Args:
            model_name: Mistral model identifier.
            max_tokens: Completion token cap passed on every request.
            api_key: Explicit key; falls back to the MISTRAL_API_KEY env var.
            temperature: Sampling temperature (was hard-coded to 0.7 in
                generate(); now configurable with the same default).

        Raises:
            ValueError: If no API key is available.
        """
        # Imported lazily so the SDK is only required when this wrapper is used.
        from mistralai import Mistral
        self.model_name = model_name
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.api_key = api_key or os.environ.get("MISTRAL_API_KEY")
        if not self.api_key:
            raise ValueError(
                "MISTRAL_API_KEY not found. Set it in environment or pass api_key parameter."
            )
        print(f"Initializing Mistral client with model: {model_name}")
        self.client = Mistral(api_key=self.api_key)
        print("Mistral client ready.")

    def generate(self, system_prompt: str, user_prompt: str, json_mode: bool = False) -> ModelResponse:
        """Run one chat completion against the Mistral API.

        Returns a ModelResponse with the stripped completion text; logprob-based
        confidence is not requested, so true_confidence is always None here.
        """
        # Mistral's JSON mode is opted into via response_format.
        response_format = {"type": "json_object"} if json_mode else None
        chat_response = self.client.chat.complete(
            model=self.model_name,
            messages=[
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {
                    "role": "user",
                    "content": user_prompt,
                },
            ],
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            response_format=response_format,
        )
        text = chat_response.choices[0].message.content.strip()
        return ModelResponse(text=text, true_confidence=None)
class GeminiWrapper(BaseModel):
    """BaseModel implementation backed by Google's Gemini API."""

    def __init__(
        self,
        model_name: str = "gemini-2.5-flash",
        api_key: Optional[str] = None
    ):
        """Create a Gemini client.

        The key comes from the api_key argument or the GEMINI_API_KEY env var;
        a missing key raises ValueError before any client is built.
        """
        # Lazy import: only pay for the SDK when this provider is selected.
        from google import genai

        self.model_name = model_name
        self.api_key = api_key or os.environ.get("GEMINI_API_KEY")
        if not self.api_key:
            raise ValueError(
                "GEMINI_API_KEY not found. Set it in environment or pass api_key parameter."
            )
        print(f"Initializing Gemini client with model: {model_name}")
        self.client = genai.Client(api_key=self.api_key)
        print("Gemini client ready.")

    def generate(self, system_prompt: str, user_prompt: str, json_mode: bool = False) -> ModelResponse:
        """Run one completion against the Gemini API and wrap it in ModelResponse."""
        from google.genai import types

        # Assemble the request config; JSON mode only adds a mime-type constraint.
        cfg = {"system_instruction": system_prompt}
        if json_mode:
            cfg["response_mime_type"] = "application/json"

        reply = self.client.models.generate_content(
            model=self.model_name,
            config=types.GenerateContentConfig(**cfg),
            contents=user_prompt,
        )
        return ModelResponse(text=reply.text.strip(), true_confidence=None)
def create_model(provider: str, model_name: Optional[str] = None) -> BaseModel:
    """Factory: return the wrapper for *provider* ('mistral' or 'gemini').

    A falsy model_name falls back to the provider's default model; any other
    provider string raises ValueError.
    """
    if provider == "gemini":
        return GeminiWrapper(model_name=model_name or "gemini-2.5-flash")
    if provider == "mistral":
        return MistralWrapper(model_name=model_name or "mistral-large-latest")
    raise ValueError(f"Unknown provider: {provider}. Use 'mistral' or 'gemini'.")