-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathagent.py
More file actions
60 lines (51 loc) · 1.87 KB
/
agent.py
File metadata and controls
60 lines (51 loc) · 1.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# import required libraries and modules
import os
from functools import lru_cache
from typing import Annotated

from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langchain_groq import ChatGroq
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict
# Load variables from a .env file into the process environment.
# override=False (the library default, spelled out here) means a value that is
# already set in the real environment is kept rather than replaced.
# NOTE(review): presumably this is how the Groq API key reaches ChatGroq — confirm.
load_dotenv(override=False)
# Define the graph state.
# This is the MEMORY of the agent: the conversation history that is carried
# between nodes and saved by the checkpointer for each thread.
class State(TypedDict):
    """Schema of the state passed between graph nodes."""

    # Conversation history. The add_messages reducer makes a node's returned
    # messages get appended to this list instead of overwriting it.
    messages: Annotated[list, add_messages]
# Define the LLM node.
# The Groq chat model is created once, on first use, and then reused. The
# original code re-created the client (and its HTTP connection) on every graph
# step. Creation stays lazy, so importing this module does not build the client.
@lru_cache(maxsize=1)
def _get_llm():
    """Return the shared ChatGroq client, creating it on the first call."""
    return ChatGroq(model="llama-3.1-8b-instant")


def chatbot(state: State):
    """Graph node: send the conversation so far to the LLM and return its reply.

    Args:
        state: Current graph state; ``state["messages"]`` is the conversation
            history accumulated for this thread.

    Returns:
        A partial state update ``{"messages": [reply]}``. Because the
        ``messages`` field is annotated with ``add_messages``, the reply is
        appended to the history rather than replacing it.
    """
    reply = _get_llm().invoke(state["messages"])
    return {"messages": [reply]}
# Build the graph: START -> chatbot -> END (one LLM call per invocation).
builder = StateGraph(State)
builder.add_node("chatbot", chatbot)
for _edge in ((START, "chatbot"), ("chatbot", END)):
    builder.add_edge(*_edge)

# Stateful part: compiling with a checkpointer lets the graph persist its state
# (the message history) per thread_id. MemorySaver keeps that state in RAM,
# so it does not survive a process restart.
memory = MemorySaver()
graph = builder.compile(checkpointer=memory)
# Run the Chat Loop
if __name__ == "__main__":
# a unique thread_id simulates a specific user session
config = {"configurable": {"thread_id": "1"}}
print("Bot: Hello! I have memory. Tell me your name, then ask me to repeat it.")
while True:
user_input = input("You: ")
if user_input.lower() in ["quit", "exit", "end"]:
break
# stream the Bot's response
events = graph.stream(
{"messages": [("user", user_input)]},
config,
stream_mode="values"
)
for event in events:
if "messages" in event:
last_message = event['messages'][-1]
if isinstance(last_message, AIMessage):
print(f"Bot: {last_message.content}")