Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/agent_engine_cli/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Dependency injection for AgentEngineClient."""

from agent_engine_cli.client import AgentEngineClient


def get_client(project: str, location: str) -> AgentEngineClient:
"""Create a new AgentEngineClient instance.

This function serves as a dependency injection point for the CLI.
Tests can patch this function to return a fake client.

Args:
project: Google Cloud project ID
location: Google Cloud region

Returns:
An instance of AgentEngineClient
"""
return AgentEngineClient(project=project, location=location)
35 changes: 17 additions & 18 deletions src/agent_engine_cli/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import json
from collections.abc import MutableMapping
from typing import Annotated, Literal

import typer
Expand All @@ -10,8 +11,8 @@

from agent_engine_cli import __version__
from agent_engine_cli.chat import run_chat
from agent_engine_cli.client import AgentEngineClient
from agent_engine_cli.config import ConfigurationError, resolve_project
from agent_engine_cli.dependencies import get_client

console = Console()

Expand Down Expand Up @@ -41,8 +42,8 @@ def list_agents(
raise typer.Exit(code=1)

try:
client = AgentEngineClient(project=project, location=location)
agents = client.list_agents()
client = get_client(project=project, location=location)
agents = list(client.list_agents())

if not agents:
console.print("No agents found.")
Expand Down Expand Up @@ -107,7 +108,7 @@ def get_agent(
raise typer.Exit(code=1)

try:
client = AgentEngineClient(project=project, location=location)
client = get_client(project=project, location=location)
agent = client.get_agent(agent_id)

# v1beta1 api_resource uses 'name' instead of 'resource_name'
Expand Down Expand Up @@ -247,7 +248,7 @@ def create_agent(
raise typer.Exit(code=1)

try:
client = AgentEngineClient(project=project, location=location)
client = get_client(project=project, location=location)
console.print(f"Creating agent '{escape(display_name)}'...")

agent = client.create_agent(
Expand Down Expand Up @@ -288,7 +289,7 @@ def delete_agent(
raise typer.Exit()

try:
client = AgentEngineClient(project=project, location=location)
client = get_client(project=project, location=location)
client.delete_agent(agent_id, force=force)
console.print(f"[red]Agent '{escape(agent_id)}' deleted.[/red]")
except Exception as e:
Expand All @@ -315,7 +316,7 @@ def list_sessions(
raise typer.Exit(code=1)

try:
client = AgentEngineClient(project=project, location=location)
client = get_client(project=project, location=location)
sessions = list(client.list_sessions(agent_id))

if not sessions:
Expand Down Expand Up @@ -383,8 +384,12 @@ def list_sandboxes(
raise typer.Exit(code=1)

try:
client = AgentEngineClient(project=project, location=location)
sandboxes = client.list_sandboxes(agent_id)
client = get_client(project=project, location=location)
sandboxes = list(client.list_sandboxes(agent_id))

if not sandboxes:
console.print("No sandboxes found.")
return

table = Table(title="Sandboxes")
table.add_column("Sandbox ID", style="cyan")
Expand All @@ -393,9 +398,7 @@ def list_sandboxes(
table.add_column("Created")
table.add_column("Expires")

has_items = False
for sandbox in sandboxes:
has_items = True
# Extract sandbox ID from full resource name
sandbox_name = getattr(sandbox, "name", "") or ""
sandbox_id = sandbox_name.split("/")[-1] if sandbox_name else ""
Expand Down Expand Up @@ -430,10 +433,6 @@ def list_sandboxes(
expire_time,
)

if not has_items:
console.print("No sandboxes found.")
return

console.print(table)
except Exception as e:
console.print(f"[red]Error listing sandboxes: {e}[/red]")
Expand All @@ -459,8 +458,8 @@ def list_memories(
raise typer.Exit(code=1)

try:
client = AgentEngineClient(project=project, location=location)
memories = client.list_memories(agent_id)
client = get_client(project=project, location=location)
memories = list(client.list_memories(agent_id))

if not memories:
console.print("No memories found.")
Expand All @@ -484,7 +483,7 @@ def list_memories(

# Format scope dict as key=value pairs
scope_raw = getattr(memory, "scope", None)
if scope_raw and isinstance(scope_raw, dict):
if scope_raw and isinstance(scope_raw, (dict, MutableMapping)):
scope = ", ".join(f"{k}={v}" for k, v in scope_raw.items())
else:
scope = ""
Expand Down
Empty file added tests/__init__.py
Empty file.
127 changes: 127 additions & 0 deletions tests/fakes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""Fake implementations for testing."""

from datetime import datetime
from typing import Any, Iterator
from dataclasses import dataclass
from enum import Enum

from google.cloud.aiplatform_v1beta1.types import ReasoningEngine, ReasoningEngineSpec, Session, Memory


class SandboxState(str, Enum):
STATE_UNSPECIFIED = "STATE_UNSPECIFIED"
STATE_RUNNING = "STATE_RUNNING"
STATE_STOPPED = "STATE_STOPPED"


@dataclass
class Sandbox:
"""Fake Sandbox since the real one is not easily importable."""
name: str
display_name: str
state: Any
create_time: datetime
expire_time: datetime


@dataclass
class CreateAgentCall:
"""Records the arguments passed to create_agent."""
display_name: str
identity_type: str
service_account: str | None


class FakeAgentEngineClient:
"""Fake client for Agent Engine."""

def __init__(self, project: str, location: str):
self.project = project
self.location = location
self._agents: dict[str, ReasoningEngine] = {}
self._sessions: dict[str, list[Session]] = {}
self._sandboxes: dict[str, list[Sandbox]] = {}
self._memories: dict[str, list[Memory]] = {}
self.create_agent_calls: list[CreateAgentCall] = []

def _get_full_name(self, resource_type: str, resource_id: str) -> str:
return f"projects/{self.project}/locations/{self.location}/{resource_type}/{resource_id}"

def list_agents(self) -> Iterator[ReasoningEngine]:
return iter(self._agents.values())

def get_agent(self, agent_id: str) -> ReasoningEngine:
if "/" in agent_id:
name = agent_id
else:
name = self._get_full_name("reasoningEngines", agent_id)

if name in self._agents:
return self._agents[name]

for agent_name, agent in self._agents.items():
if agent_name.endswith(f"/{agent_id}"):
return agent

raise Exception(f"Agent {agent_id} not found")

def create_agent(
self,
display_name: str,
identity_type: str,
service_account: str | None = None,
) -> ReasoningEngine:
self.create_agent_calls.append(
CreateAgentCall(display_name=display_name, identity_type=identity_type, service_account=service_account)
)

agent_id = f"agent-{len(self._agents) + 1}"
name = self._get_full_name("reasoningEngines", agent_id)

spec = ReasoningEngineSpec(agent_framework="langchain")
agent = ReasoningEngine(
name=name,
display_name=display_name,
spec=spec,
create_time=datetime.now(),
update_time=datetime.now(),
)

self._agents[name] = agent
return agent

def delete_agent(self, agent_id: str, force: bool = False) -> None:
if "/" in agent_id:
name = agent_id
else:
name = self._get_full_name("reasoningEngines", agent_id)

if name not in self._agents:
for agent_name in list(self._agents.keys()):
if agent_name.endswith(f"/{agent_id}"):
name = agent_name
break

if name in self._agents:
if not force:
if self._sessions.get(name) or self._memories.get(name) or self._sandboxes.get(name):
raise Exception("Agent has resources, use force to delete")

del self._agents[name]
self._sessions.pop(name, None)
self._sandboxes.pop(name, None)
self._memories.pop(name, None)
else:
raise Exception(f"Agent {agent_id} not found")

def list_sessions(self, agent_id: str) -> Iterator[Session]:
agent = self.get_agent(agent_id)
return iter(self._sessions.get(agent.name, []))

def list_sandboxes(self, agent_id: str) -> Iterator[Sandbox]:
agent = self.get_agent(agent_id)
return iter(self._sandboxes.get(agent.name, []))

def list_memories(self, agent_id: str) -> Iterator[Memory]:
agent = self.get_agent(agent_id)
return iter(self._memories.get(agent.name, []))
6 changes: 3 additions & 3 deletions tests/test_fix_attribute_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

runner = CliRunner()

@patch("agent_engine_cli.main.AgentEngineClient")
def test_get_agent_with_none_class_methods(mock_client_class):
@patch("agent_engine_cli.main.get_client")
def test_get_agent_with_none_class_methods(mock_get_client):
"""Test get command when spec.class_methods is None (regression test)."""

# Mock spec with class_methods=None
Expand All @@ -25,7 +25,7 @@ def test_get_agent_with_none_class_methods(mock_client_class):

mock_client = MagicMock()
mock_client.get_agent.return_value = mock_agent
mock_client_class.return_value = mock_client
mock_get_client.return_value = mock_client

result = runner.invoke(
app, ["get", "agent1", "--project", "test-project", "--location", "us-central1"]
Expand Down
6 changes: 3 additions & 3 deletions tests/test_get_effective_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

runner = CliRunner()

@patch("agent_engine_cli.main.AgentEngineClient")
def test_get_agent_shows_effective_identity(mock_client_class):
@patch("agent_engine_cli.main.get_client")
def test_get_agent_shows_effective_identity(mock_get_client):
"""Test get command shows effective identity."""
mock_spec = MagicMock()
mock_spec.effective_identity = "service-account@test.iam.gserviceaccount.com"
Expand All @@ -27,7 +27,7 @@ def test_get_agent_shows_effective_identity(mock_client_class):

mock_client = MagicMock()
mock_client.get_agent.return_value = mock_agent
mock_client_class.return_value = mock_client
mock_get_client.return_value = mock_client

result = runner.invoke(
app, ["get", "agent1", "--project", "test-project", "--location", "us-central1"]
Expand Down
Loading