-
Notifications
You must be signed in to change notification settings - Fork 183
Expand file tree
/
Copy pathpydanticai_tools.py
More file actions
93 lines (75 loc) · 3 KB
/
pydanticai_tools.py
File metadata and controls
93 lines (75 loc) · 3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import asyncio
import logging
import os
import random
from datetime import datetime
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
from dotenv import load_dotenv
from openai import AsyncOpenAI
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.openai import OpenAIProvider
from rich.logging import RichHandler
# Render log records through Rich for readable console output. The root logger
# stays at WARNING so third-party libraries remain quiet; this script's own
# logger is looked up by name so its level can be raised independently.
logging.basicConfig(
    handlers=[RichHandler()],
    datefmt="[%X]",
    format="%(message)s",
    level=logging.WARNING,
)
logger = logging.getLogger("weekend_planner")
# Select the chat backend from the API_HOST environment variable and wrap it as a
# Pydantic AI model:
#   "azure" (default) -> Azure OpenAI v1 endpoint with keyless bearer-token auth
#   "ollama"          -> local Ollama server via its OpenAI-compatible API
#   anything else     -> OpenAI.com using OPENAI_API_KEY
load_dotenv(override=True)
API_HOST = os.getenv("API_HOST", "azure")

# Only the Azure branch creates a credential; main() closes it when it is set.
async_credential = None
if API_HOST == "azure":
    async_credential = DefaultAzureCredential()
    # Bearer tokens for the Cognitive Services scope, obtained from the ambient
    # Azure identity instead of an API key.
    token_provider = get_bearer_token_provider(async_credential, "https://cognitiveservices.azure.com/.default")
    client = AsyncOpenAI(
        # Strip a trailing slash (as shown in the Azure portal, e.g.
        # "https://<name>.openai.azure.com/") so the URL doesn't become "//openai/v1".
        base_url=os.environ["AZURE_OPENAI_ENDPOINT"].rstrip("/") + "/openai/v1",
        # NOTE(review): passes the token-provider callable as api_key; relies on the
        # installed openai SDK accepting a callable here — confirm the version.
        api_key=token_provider,
    )
    model = OpenAIChatModel(os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"], provider=OpenAIProvider(openai_client=client))
elif API_HOST == "ollama":
    # Placeholder key: the SDK needs a value, presumably ignored by Ollama.
    client = AsyncOpenAI(base_url=os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434/v1"), api_key="none")
    model = OpenAIChatModel(os.environ["OLLAMA_MODEL"], provider=OpenAIProvider(openai_client=client))
else:
    client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
    model = OpenAIChatModel(os.environ.get("OPENAI_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
def get_weather(city: str) -> dict:
    """Return the current weather for a city.

    Registered as an agent tool; Pydantic AI builds the tool's description
    (and parameter descriptions) from this docstring.

    This is stub data: about 5% of calls report sunny weather with
    temperature 72, and the rest report rainy weather with temperature 60.

    Args:
        city: Name of the city to get the weather for.

    Returns:
        A dict with keys ``city``, ``temperature`` and ``description``.
    """
    logger.info("Getting weather for %s", city)
    # Sunny only rarely; otherwise rainy.
    if random.random() < 0.05:
        return {"city": city, "temperature": 72, "description": "Sunny"}
    return {"city": city, "temperature": 60, "description": "Rainy"}
def get_activities(city: str, date: str) -> list:
    """Return activities available in a city on a given date.

    Registered as an agent tool; Pydantic AI builds the tool's description
    (and parameter descriptions) from this docstring.

    This is stub data: the same three activities are returned for any city,
    and ``date`` is only logged.

    Args:
        city: Name of the city to find activities in.
        date: Date of interest, e.g. in the YYYY-MM-DD form that
            get_current_date returns.

    Returns:
        A new list of dicts, each with ``name`` and ``location`` keys.
    """
    logger.info("Getting activities for %s on %s", city, date)
    return [
        {"name": "Hiking", "location": city},
        {"name": "Beach", "location": city},
        {"name": "Museum", "location": city},
    ]
def get_current_date() -> str:
    """Return today's date from the local system clock, formatted as YYYY-MM-DD."""
    logger.info("Getting current date")
    now = datetime.now()
    # A format spec on a datetime in an f-string delegates to strftime.
    return f"{now:%Y-%m-%d}"
# The weekend-planning agent: the selected chat model plus three plain-function tools.
agent = Agent(
    model,
    system_prompt=(
        # Adjacent string literals are concatenated with nothing in between, so
        # each sentence carries its own trailing space to keep them separated.
        "You help users plan their weekends and choose the best activities for the given weather. "
        "If an activity would be unpleasant in the weather, don't suggest it. "
        "Include the date of the weekend in your response."
    ),
    tools=[get_weather, get_activities, get_current_date],
)
async def main():
    """Run one weekend-planning query through the agent and print its answer.

    The Azure credential (created only when API_HOST is "azure") is closed in a
    ``finally`` block, so it is released even if the agent run raises.
    """
    try:
        result = await agent.run("what can I do for funzies this weekend in Seattle?")
        print(result.output)
    finally:
        if async_credential:
            await async_credential.close()
if __name__ == "__main__":
    # Run as a script: let this module's INFO-level tool traces through (the
    # root logger stays at WARNING), then drive main() on a fresh event loop.
    logger.setLevel("INFO")
    asyncio.run(main())