# agent.py: a LangGraph chat agent that pairs ChatOllama with Tavily web search and in-memory checkpointing.
import os
from typing import Annotated, TypedDict

from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import SystemMessage, ToolMessage
from langchain_ollama import ChatOllama
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages


def load_api_key(file_path: str) -> str:
    """Read an API key from a text file and return it, raising if the file is missing or empty."""
    try:
        with open(file_path, "r") as file:
            key = file.read().strip()
        if not key:
            raise ValueError(f"API key file is empty: {file_path}")
        return key
    except FileNotFoundError:
        raise FileNotFoundError(f"API key file not found: {file_path}")
    except ValueError:
        # Re-raise the empty-file error unchanged instead of wrapping it below
        raise
    except Exception as e:
        raise RuntimeError(f"Error reading API key: {e}") from e


TAVILY_API_KEY = load_api_key("TavilyKey.txt")
LANGCHAIN_API_KEY = load_api_key("LangSmithKey.txt")

# Environment Setup
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_API_KEY"] = LANGCHAIN_API_KEY
os.environ["TAVILY_API_KEY"] = TAVILY_API_KEY
os.environ["TOKENIZERS_PARALLELISM"] = "true"


# Agent State Definition
class State(TypedDict):
    messages: Annotated[list, add_messages]
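
# Note: add_messages is LangGraph's reducer; it appends (and merges by message id)
# incoming messages onto the existing list rather than replacing it, so each node
# only needs to return the messages it produced.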


# Manual ToolNode implementation
class BasicToolNode:
    """Executes the tool calls requested in the last message of the state."""

    def __init__(self, tools: list):
        self.tools_by_name = {tool.name: tool for tool in tools}

    def __call__(self, state: dict):
        messages = state.get("messages", [])
        if not messages:
            raise ValueError("No messages found in state")
        last_message = messages[-1]
        outputs = []
        for tool_call in last_message.tool_calls:
            tool = self.tools_by_name[tool_call["name"]]
            observation = tool.invoke(tool_call["args"])
            outputs.append(
                ToolMessage(
                    content=str(observation),
                    tool_call_id=tool_call["id"],
                )
            )
        return {"messages": outputs}


def basic_tools_condition(state: dict):
    """Route to the tools node when the last message contains tool calls, otherwise end."""
    messages = state.get("messages", [])
    if messages and getattr(messages[-1], "tool_calls", None):
        return "tools"
    return END
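
# Note: langgraph.prebuilt ships ToolNode and tools_condition, the maintained
# equivalents of the two manual implementations above; they could be swapped in
# directly if the hand-rolled versions are not needed for learning purposes.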


# Initialize Tools
search_tool = TavilySearchResults(max_results=3)
tools = [search_tool]
tool_node = BasicToolNode(tools)

# Initialize LLM (bind_tools attaches the tool schemas so the model can emit tool calls)
llm = ChatOllama(model="gpt-oss:20b", temperature=0)
llm_with_tools = llm.bind_tools(tools)


# Define the Agent Logic
def agent_node(state: State):
    system_prompt = SystemMessage(
        content=(
            "You are a professional, helpful, and friendly AI Assistant. "
            "1. If the user greets you (e.g., 'Hi', 'Hello', 'Hey'), respond with a warm and professional greeting. "
            "2. When the user asks a specific question or query, use the available search tools to provide an accurate and detailed answer. "
            "3. IMPORTANT: NEVER mention that you are using tools, searching the web, or calling functions. Just provide the final answer directly. "
            "4. Maintain a high-quality, professional tone at all times. "
            "5. Do not explain your thought process or internal logic."
        )
    )
    messages = [system_prompt] + state["messages"]
    response = llm_with_tools.invoke(messages)
    return {"messages": [response]}


# Build the Graph
workflow = StateGraph(State)
workflow.add_node("agent", agent_node)
workflow.add_node("tools", tool_node)
workflow.add_edge(START, "agent")
workflow.add_conditional_edges("agent", basic_tools_condition)
workflow.add_edge("tools", "agent")

# Initialize Memory
memory = MemorySaver()
chat_graph = workflow.compile(checkpointer=memory)
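
# A minimal usage sketch (not part of the original file): MemorySaver keys
# conversation state by thread_id, so each invocation must pass one in the
# config. The thread id and question below are placeholders.
if __name__ == "__main__":
    config = {"configurable": {"thread_id": "demo-thread"}}
    result = chat_graph.invoke(
        {"messages": [("user", "What is LangGraph?")]},
        config=config,
    )
    # Print the assistant's final reply
    print(result["messages"][-1].content)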