# streamlit_ui.py
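"""Streamlit chat UI for the pydantic_ai_expert agent.

Streams the agent's answer incrementally and keeps the full conversation in
st.session_state so the chat history survives Streamlit reruns.
"""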
from __future__ import annotations
from typing import Literal, TypedDict
import asyncio
import os
import streamlit as st
import json
import logfire
from supabase import Client, create_client
from openai import AsyncOpenAI
####### HF Transformers ######
# from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# from dataclasses import dataclass
# from pydantic_ai import Agent
# # Load a Hugging Face model
# MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
# # Create a text generation pipeline
# hf_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
####### HF Transformers ######
# Import all the message part classes
from pydantic_ai.messages import (
    ModelMessage,
    ModelRequest,
    ModelResponse,
    SystemPromptPart,
    UserPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    RetryPromptPart,
    ModelMessagesTypeAdapter,
)
from pydantic_ai_expert import pydantic_ai_expert, PydanticAIDeps
# Load environment variables
from dotenv import load_dotenv
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
db_password = os.getenv("DB_PASSWORD")
supabase_url = os.getenv("SUPABASE_URL")
supabase_secret = os.getenv("SUPABASE_SECRET")
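# Expected .env keys (per the os.getenv calls above): OPENAI_API_KEY,
# SUPABASE_URL, SUPABASE_SECRET; DB_PASSWORD is read but unused in this file.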
openai_client = AsyncOpenAI(api_key=openai_api_key)
supabase: Client = create_client(
    supabase_url,
    supabase_secret,
)
# Configure logfire to suppress warnings (optional)
logfire.configure(send_to_logfire='never')
class ChatMessage(TypedDict):
    """Format of messages sent to the browser/API."""
    role: Literal['user', 'model']
    timestamp: str
    content: str
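# Illustrative shape (hypothetical values):
# {"role": "user", "timestamp": "2025-01-01T12:00:00Z", "content": "Hello"}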
def display_message_part(part):
    """
    Display a single part of a message in the Streamlit UI.
    Customize how you display system prompts, user prompts,
    tool calls, tool returns, etc.
    """
    # system-prompt
    if part.part_kind == 'system-prompt':
        with st.chat_message("system"):
            st.markdown(f"**System**: {part.content}")
    # user-prompt
    elif part.part_kind == 'user-prompt':
        with st.chat_message("user"):
            st.markdown(part.content)
    # text
    elif part.part_kind == 'text':
        with st.chat_message("assistant"):
            st.markdown(part.content)
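# Note: tool-call, tool-return, and retry-prompt parts are imported above but
# not rendered here; add branches for those part kinds if you want them shown.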
async def run_agent_with_streaming(user_input: str):
    """
    Run the agent with streaming text for the user_input prompt,
    while maintaining the entire conversation in `st.session_state.messages`.
    """
    # Prepare dependencies
    deps = PydanticAIDeps(
        supabase=supabase,
        openai_client=openai_client
    )
    # Hugging Face variant (see the commented-out pipeline above):
    # deps = PydanticAIDeps(
    #     supabase=supabase,
    #     hf_pipeline=hf_pipeline  # Replacing OpenAI client with Hugging Face
    # )
    # Run the agent in a stream
    async with pydantic_ai_expert.run_stream(
        user_input,
        deps=deps,
        message_history=st.session_state.messages[:-1],  # pass entire conversation so far
    ) as result:
        # We'll gather partial text to show incrementally
        partial_text = ""
        message_placeholder = st.empty()
        # Render partial text as it arrives
        async for chunk in result.stream_text(delta=True):
            partial_text += chunk
            message_placeholder.markdown(partial_text)
        # Now that the stream is finished, we have a final result.
        # Add new messages from this run, excluding user-prompt messages
        # (the user prompt was already appended to session state in main(),
        # so filtering here avoids duplicating it in the history).
        filtered_messages = [
            msg for msg in result.new_messages()
            if not (hasattr(msg, 'parts') and
                    any(part.part_kind == 'user-prompt' for part in msg.parts))
        ]
        st.session_state.messages.extend(filtered_messages)
        # Add the final response to the messages
        st.session_state.messages.append(
            ModelResponse(parts=[TextPart(content=partial_text)])
        )
async def main():
    st.title("Environmental Expert")
    st.write("Ask any question about the Environment and Earth Science")
    # Initialize chat history in session state if not present
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # Display all messages from the conversation so far.
    # Each message is either a ModelRequest or a ModelResponse;
    # we iterate over their parts to decide how to display them.
    for msg in st.session_state.messages:
        if isinstance(msg, (ModelRequest, ModelResponse)):
            for part in msg.parts:
                display_message_part(part)
    # Chat input for the user
    user_input = st.chat_input("What questions do you have about the Environment?")
    if user_input:
        # Append a new request to the conversation explicitly
        st.session_state.messages.append(
            ModelRequest(parts=[UserPromptPart(content=user_input)])
        )
        # Display the user prompt in the UI
        with st.chat_message("user"):
            st.markdown(user_input)
        # Display the assistant's partial response while streaming
        with st.chat_message("assistant"):
            # Run the agent now, streaming the text
            await run_agent_with_streaming(user_input)

if __name__ == "__main__":
    asyncio.run(main())
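# Launch locally (assuming Streamlit is installed):
#   streamlit run streamlit_ui.py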