-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllm.py
More file actions
36 lines (31 loc) · 1.61 KB
/
llm.py
File metadata and controls
36 lines (31 loc) · 1.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from config import API_PROVIDER, GOOGLE_API_KEY, OPENAI_API_KEY
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.schema import HumanMessage, AIMessage, SystemMessage
class LLMClient:
    """Thin wrapper around a LangChain chat model (Gemini or OpenAI).

    The backing provider is chosen once at construction time from the
    module-level ``API_PROVIDER`` setting; both providers expose the same
    ``.invoke(messages)`` interface, so the rest of the class is
    provider-agnostic.
    """

    def __init__(self):
        # Normalize so "google", " OpenAI " etc. also match — strictly more
        # permissive than the original exact-match, so backward-compatible.
        provider = (API_PROVIDER or "").strip().upper()
        if provider == "GOOGLE":
            self.client = ChatGoogleGenerativeAI(
                model="gemini-1.5-flash", google_api_key=GOOGLE_API_KEY
            )
        elif provider == "OPENAI":
            self.client = ChatOpenAI(
                model="gpt-3.5-turbo", openai_api_key=OPENAI_API_KEY
            )
        else:
            # Include the offending value so a misconfigured deployment is
            # easy to diagnose from the traceback alone.
            raise ValueError(
                f"Invalid API Provider: {API_PROVIDER!r} "
                "(expected 'GOOGLE' or 'OPENAI')"
            )

        # ✅ Define ChatPromptTemplate (Structured Messages)
        # Used by generate_response(); chat_completion() builds messages ad hoc.
        self.chat_prompt = ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template("You are an AI assistant."),
            HumanMessagePromptTemplate.from_template(
                "Rewrite this customer review in a professional tone: {user_input}"
            ),
        ])

    def generate_response(self, user_input):
        """Generate response using ChatPromptTemplate.

        Formats ``user_input`` into the stored prompt (professional-tone
        rewrite) and returns the model's text content.
        """
        formatted_messages = self.chat_prompt.format_messages(user_input=user_input)
        return self.client.invoke(formatted_messages).content

    def chat_completion(self, user_input):
        """Handle OpenAI & Gemini Messages with Role.

        Sends ``user_input`` as a plain chat turn preceded by a system role
        message and returns the model's text content.
        """
        messages = [
            SystemMessage(content="You are an AI assistant."),
            HumanMessage(content=user_input),
        ]
        return self.client.invoke(messages).content
# ✅ Initialize LLM
# NOTE: module-level singleton — constructed at import time, so importing
# this module immediately selects a provider (and raises ValueError if
# API_PROVIDER is neither "GOOGLE" nor "OPENAI").
llm = LLMClient()