Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
OPEN_ROUTER_API_KEY =''
API_KEY=''
BASE_URL=''
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ __pycache__/

# C extensions
*.so
.ai_agent
config.toml

# Distribution / packaging
.Python
Expand Down
46 changes: 23 additions & 23 deletions agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
import json
from pathlib import Path
from typing import AsyncGenerator

from client.llm_client import LLMClient
from client.response import StreamEventType, ToolCall, ToolResultMessage
from context.manager import ContextManager
from tools.registry import create_default_registry
from config.config import Config

from .event import AgentEvent, AgentEventType
from .session import Session


class Agent:
def __init__(self):
self.client = LLMClient()
self.context_manager = ContextManager()
self.tool_registry = create_default_registry()
def __init__(self, config: Config):
self.config = config
self.session = Session(self.config)

async def run(self, message: str):
yield AgentEvent.agent_start(message)
self.context_manager.add_user_message(message)
self.session.context_manager.add_user_message(message)
final_response = None
async for event in self._agentic_loop():
yield event
Expand All @@ -29,13 +25,16 @@ async def run(self, message: str):
yield AgentEvent.agent_end(final_response)

async def _agentic_loop(self) -> AsyncGenerator[AgentEvent, None]:
while True:
max_turns = self.config.max_tunrs

while True and max_turns > self.session._turn_count():
self.session.increment_turn()
response_text = ""
tool_calls: list[ToolCall] = []
tool_schemas = self.tool_registry.get_schemas()
tool_schemas = self.session.tool_registry.get_schemas()

async for event in self.client.chat_completion(
self.context_manager.get_messages(),
async for event in self.session.client.chat_completion(
self.session.context_manager.get_messages(),
tools=tool_schemas if tool_schemas else None,
stream=True,
):
Expand Down Expand Up @@ -66,7 +65,10 @@ async def _agentic_loop(self) -> AsyncGenerator[AgentEvent, None]:
for tc in tool_calls
]

self.context_manager.add_assistant_message(
else:
break

self.session.context_manager.add_assistant_message(
response_text or None, tool_calls=api_tool_calls
)

Expand All @@ -77,10 +79,10 @@ async def _agentic_loop(self) -> AsyncGenerator[AgentEvent, None]:
tool_call.call_id, tool_call.name, tool_call.arguments
)

result = await self.tool_registry.invoke(
result = await self.session.tool_registry.invoke(
tool_call.name,
tool_call.arguments,
Path.cwd(),
self.config.cwd,
)
yield AgentEvent.tool_call_complete(
call_id=tool_call.call_id, name=tool_call.name, result=result
Expand All @@ -94,19 +96,17 @@ async def _agentic_loop(self) -> AsyncGenerator[AgentEvent, None]:
)

for tool_result in tool_call_results:
self.context_manager.add_tool_message(
self.session.context_manager.add_tool_message(
tool_result.tool_call_id, tool_result.content
)

if response_text:
yield AgentEvent.text_complete(response_text)

break

async def __aenter__(self) -> Agent:
return self

async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
if self.client:
await self.client.close()
self.client = None
if self.session and self.session.client:
await self.session.client.close()
self.session = None
25 changes: 25 additions & 0 deletions agent/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import uuid
from datetime import datetime

from client.llm_client import LLMClient
from config.config import Config
from context.manager import ContextManager
from tools.registry import create_default_registry


class Session:
    """Per-conversation state bundle: LLM client, context, tools, and metadata.

    Owns everything the Agent needs for one conversation so the Agent itself
    stays stateless apart from its config.
    """

    def __init__(self, config: Config):
        self.config = config
        self.client = LLMClient(config=config)
        self.context_manager = ContextManager(config)
        self.tool_registry = create_default_registry()
        # uuid.uuid7 (time-ordered UUIDs) only exists on Python >= 3.14;
        # fall back to uuid4 so session creation does not raise
        # AttributeError on older interpreters.
        self.session_id = str(getattr(uuid, "uuid7", uuid.uuid4)())
        self.created_at = datetime.now()
        self.updated_at = datetime.now()

        # Number of agentic turns taken so far (int attribute).
        # NOTE(review): agent.py calls `self.session._turn_count()` as if this
        # were a method — that call would raise TypeError; confirm callers.
        self._turn_count = 0

    def increment_turn(self) -> int:
        """Advance the turn counter, refresh ``updated_at``, return new count."""
        self._turn_count += 1
        self.updated_at = datetime.now()
        return self._turn_count
29 changes: 19 additions & 10 deletions client/llm_client.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,36 @@
import asyncio
import os
from pathlib import Path
from typing import Any, AsyncGenerator, Optional

from dotenv import load_dotenv
from openai import APIConnectionError, APIError, AsyncOpenAI, RateLimitError

from .response import (StreamEvent, StreamEventType, TextDelta, TokenUsage,
ToolCall, ToolCallDelta, parse_tool_call_arguments)
from config.config import Config

from .response import (
StreamEvent,
StreamEventType,
TextDelta,
TokenUsage,
ToolCall,
ToolCallDelta,
parse_tool_call_arguments,
)

load_dotenv(Path(__file__).parent.parent / ".env")


class LLMClient:
def __init__(self):
def __init__(self, config: Config):
self._client: AsyncOpenAI | None = None
self._max_retry: int = 3
self.config = config

def get_client(self) -> AsyncOpenAI:
if self._client is None:
self._client = AsyncOpenAI(
api_key=os.getenv("OPEN_ROUTER_API_KEY"),
base_url="https://openrouter.ai/api/v1",
api_key=self.config.api_key,
base_url=self.config.base_url,
)
return self._client

Expand Down Expand Up @@ -52,7 +61,7 @@ async def chat_completion(
) -> AsyncGenerator[StreamEvent, None]:
client = self.get_client()
kwargs = {
"model": "mistralai/devstral-2512:free",
"model": self.config.model_name,
"messages": messages,
"stream": stream,
}
Expand Down Expand Up @@ -146,9 +155,9 @@ async def _stream_response(
)

if tool_call_delta.function.arguments:
tool_calls[idx][
"arguments"
] += tool_call_delta.function.arguments
tool_calls[idx]["arguments"] += (
tool_call_delta.function.arguments
)

yield StreamEvent(
type=StreamEventType.TOOL_CALL_DELTA,
Expand Down
65 changes: 65 additions & 0 deletions config/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
from pathlib import Path
from typing import Any, Optional

from pydantic import BaseModel, Field


class ModelConfig(BaseModel):
    """Settings describing the LLM backing the agent."""

    # OpenRouter-style model identifier.
    name: str = "mistralai/devstral-2512:free"
    # Sampling temperature; default is now the float literal 1.0 to match the
    # declared type (was the int literal 1, which pydantic leaves unvalidated
    # as a default).
    temperature: float = Field(default=1.0, ge=0.0, le=2.0)
    # Context size (tokens) assumed for the model.
    context_window: int = 250_000


class Config(BaseModel):
    """Runtime configuration for the agent.

    Credentials (``api_key`` / ``base_url``) are read from the environment on
    every access rather than stored, so they never end up in ``to_dict()``.
    """

    model: ModelConfig = Field(default_factory=ModelConfig)
    # Directory the agent operates in; defaults to the process cwd.
    cwd: Path = Field(default_factory=Path.cwd)

    # NOTE(review): field name keeps the historical typo ("tunrs") because
    # existing callers read `config.max_tunrs` (see agent.py); new code should
    # use the `max_turns` property alias below.
    max_tunrs: int = 100
    max_tool_output_tokens: int = 50_000

    developer_instructions: Optional[str] = None
    user_instructions: Optional[str] = None

    debug: bool = False

    @property
    def max_turns(self) -> int:
        """Correctly-spelled, read-only alias for the ``max_tunrs`` field."""
        return self.max_tunrs

    @property
    def api_key(self) -> Optional[str]:
        """API key from the ``API_KEY`` environment variable, or None."""
        return os.environ.get("API_KEY")

    @property
    def base_url(self) -> Optional[str]:
        """Endpoint base URL from ``BASE_URL`` environment variable, or None."""
        return os.environ.get("BASE_URL")

    @property
    def model_name(self) -> str:
        """Convenience passthrough to ``model.name``."""
        return self.model.name

    @model_name.setter
    def model_name(self, value: str) -> None:
        self.model.name = value

    @property
    def temperature(self) -> float:
        """Convenience passthrough to ``model.temperature``."""
        return self.model.temperature

    @temperature.setter
    def temperature(self, value: float) -> None:
        # Annotation fixed: setter previously declared `value: str` for a
        # float-typed field.
        self.model.temperature = value

    def validate(self) -> list[str]:
        """Return a list of human-readable configuration problems (empty = OK).

        NOTE(review): this shadows pydantic's (deprecated) ``BaseModel.validate``
        classmethod; callers in this project expect the list-returning form.
        """
        errors: list[str] = []

        if not self.api_key:
            errors.append("No API key found. Set API_KEY environment variable")

        if not self.base_url:
            errors.append("No base url available.")

        if not self.cwd.exists():
            errors.append(f"Working directory does not exist: {self.cwd}")

        return errors

    def to_dict(self) -> dict[str, Any]:
        """JSON-safe dict of the declared fields (env-backed props excluded)."""
        return self.model_dump(mode="json")
110 changes: 110 additions & 0 deletions config/loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import logging
from pathlib import Path
from typing import Any, Optional

import tomli
from platformdirs import user_config_dir, user_data_dir

from config.config import Config
from utils.errors import ConfigError

logger = logging.getLogger(__name__)
CONFIG_FILE_NAME = "config.toml"

AGENT_MD_FILE = "AGENT.MD"


def get_config_dir() -> Path:
    """Return the per-user configuration directory for the agent."""
    config_root = user_config_dir(".ai_agent")
    return Path(config_root)


def get_data_dir() -> Path:
    """Return the per-user data directory for the agent."""
    data_root = user_data_dir(".ai_agent")
    return Path(data_root)


def get_system_config_path() -> Path:
    """Return the full path of the system-wide ``config.toml``."""
    base = get_config_dir()
    return base / CONFIG_FILE_NAME


def _parse_toml(path: Path):
    """Parse *path* as TOML, wrapping parse and I/O failures in ConfigError.

    Raises:
        ConfigError: on invalid TOML or on any failure to read the file.
    """
    try:
        with open(path, "rb") as f:
            return tomli.load(f)
    except tomli.TOMLDecodeError as e:
        # Bug fix: both messages lacked the f-prefix, so users saw the
        # literal text "{path}: {e}" instead of the actual path and error.
        raise ConfigError(f"Invalid TOML in {path}: {e}", config_file=str(path)) from e
    except OSError as e:  # IOError is an alias of OSError since Python 3.3
        raise ConfigError(
            f"Failed to read config file {path}: {e}", config_file=str(path)
        ) from e


def _get_project_config(cwd: Path) -> Optional[Path]:
current = cwd.resolve()
agent_dir = current / ".ai_agent"

if agent_dir.is_dir():
config_file = agent_dir / CONFIG_FILE_NAME
if config_file.is_file():
return config_file

return None


def _get_agent_md_files(cwd: Path) -> Optional[Path]:
current = cwd.resolve()

if current.is_dir():
agent_md_file = current / AGENT_MD_FILE
if agent_md_file.is_file():
content = agent_md_file.read_text(encoding="utf-8")
return content

return None


def _merge_dicts(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
result = base.copy()
for key, value in override.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = _merge_dicts(result[key], value)
else:
result[key] = value

return result


def load_config(cwd: Optional[Path] = None) -> Config:
    """Build a :class:`Config` by layering configuration sources.

    Precedence (lowest to highest): system config, project-local config,
    then derived defaults (``cwd``, ``AGENT.MD`` developer instructions).
    Invalid config files are skipped with a warning rather than aborting.

    Raises:
        ConfigError: when the merged dict fails Config validation.
    """
    cwd = cwd or Path.cwd()

    config_dict: dict[str, Any] = {}

    system_path = get_system_config_path()
    if system_path.is_file():
        try:
            config_dict = _parse_toml(system_path)
        except ConfigError:
            # Best effort: a broken system config must not block startup.
            # Lazy %-args instead of f-strings per logging best practice.
            logger.warning("Skipping invalid system config: %s", system_path)

    project_path = _get_project_config(cwd)
    if project_path:
        try:
            project_config_dict = _parse_toml(project_path)
            # Project config overrides system config key-by-key.
            config_dict = _merge_dicts(config_dict, project_config_dict)
        except ConfigError:
            logger.warning("Skipping invalid project config: %s", project_path)

    config_dict.setdefault("cwd", cwd)

    if "developer_instructions" not in config_dict:
        agent_md_content = _get_agent_md_files(cwd)
        if agent_md_content:
            config_dict["developer_instructions"] = agent_md_content

    try:
        return Config(**config_dict)
    except Exception as e:
        # Wrap pydantic validation failures in the project's error type.
        raise ConfigError(f"Invalid configuration: {e}") from e
Loading