Skip to content

Commit 1255ab9

Browse files
authored
Merge pull request #8 from mmontan/refactor-testing-fakes-12025630112277388695
Refactor testing to use Fakes and Dependency Injection
2 parents c0500a8 + 990ffd2 commit 1255ab9

7 files changed

Lines changed: 392 additions & 274 deletions

File tree

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""Dependency injection for AgentEngineClient."""
2+
3+
from agent_engine_cli.client import AgentEngineClient
4+
5+
6+
def get_client(project: str, location: str) -> AgentEngineClient:
7+
"""Create a new AgentEngineClient instance.
8+
9+
This function serves as a dependency injection point for the CLI.
10+
Tests can patch this function to return a fake client.
11+
12+
Args:
13+
project: Google Cloud project ID
14+
location: Google Cloud region
15+
16+
Returns:
17+
An instance of AgentEngineClient
18+
"""
19+
return AgentEngineClient(project=project, location=location)

src/agent_engine_cli/main.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import json
3+
from collections.abc import MutableMapping
34
from typing import Annotated, Literal
45

56
import typer
@@ -10,8 +11,8 @@
1011

1112
from agent_engine_cli import __version__
1213
from agent_engine_cli.chat import run_chat
13-
from agent_engine_cli.client import AgentEngineClient
1414
from agent_engine_cli.config import ConfigurationError, resolve_project
15+
from agent_engine_cli.dependencies import get_client
1516

1617
console = Console()
1718

@@ -41,8 +42,8 @@ def list_agents(
4142
raise typer.Exit(code=1)
4243

4344
try:
44-
client = AgentEngineClient(project=project, location=location)
45-
agents = client.list_agents()
45+
client = get_client(project=project, location=location)
46+
agents = list(client.list_agents())
4647

4748
if not agents:
4849
console.print("No agents found.")
@@ -107,7 +108,7 @@ def get_agent(
107108
raise typer.Exit(code=1)
108109

109110
try:
110-
client = AgentEngineClient(project=project, location=location)
111+
client = get_client(project=project, location=location)
111112
agent = client.get_agent(agent_id)
112113

113114
# v1beta1 api_resource uses 'name' instead of 'resource_name'
@@ -247,7 +248,7 @@ def create_agent(
247248
raise typer.Exit(code=1)
248249

249250
try:
250-
client = AgentEngineClient(project=project, location=location)
251+
client = get_client(project=project, location=location)
251252
console.print(f"Creating agent '{escape(display_name)}'...")
252253

253254
agent = client.create_agent(
@@ -288,7 +289,7 @@ def delete_agent(
288289
raise typer.Exit()
289290

290291
try:
291-
client = AgentEngineClient(project=project, location=location)
292+
client = get_client(project=project, location=location)
292293
client.delete_agent(agent_id, force=force)
293294
console.print(f"[red]Agent '{escape(agent_id)}' deleted.[/red]")
294295
except Exception as e:
@@ -315,7 +316,7 @@ def list_sessions(
315316
raise typer.Exit(code=1)
316317

317318
try:
318-
client = AgentEngineClient(project=project, location=location)
319+
client = get_client(project=project, location=location)
319320
sessions = list(client.list_sessions(agent_id))
320321

321322
if not sessions:
@@ -383,8 +384,12 @@ def list_sandboxes(
383384
raise typer.Exit(code=1)
384385

385386
try:
386-
client = AgentEngineClient(project=project, location=location)
387-
sandboxes = client.list_sandboxes(agent_id)
387+
client = get_client(project=project, location=location)
388+
sandboxes = list(client.list_sandboxes(agent_id))
389+
390+
if not sandboxes:
391+
console.print("No sandboxes found.")
392+
return
388393

389394
table = Table(title="Sandboxes")
390395
table.add_column("Sandbox ID", style="cyan")
@@ -393,9 +398,7 @@ def list_sandboxes(
393398
table.add_column("Created")
394399
table.add_column("Expires")
395400

396-
has_items = False
397401
for sandbox in sandboxes:
398-
has_items = True
399402
# Extract sandbox ID from full resource name
400403
sandbox_name = getattr(sandbox, "name", "") or ""
401404
sandbox_id = sandbox_name.split("/")[-1] if sandbox_name else ""
@@ -430,10 +433,6 @@ def list_sandboxes(
430433
expire_time,
431434
)
432435

433-
if not has_items:
434-
console.print("No sandboxes found.")
435-
return
436-
437436
console.print(table)
438437
except Exception as e:
439438
console.print(f"[red]Error listing sandboxes: {escape(str(e))}[/red]")
@@ -459,8 +458,8 @@ def list_memories(
459458
raise typer.Exit(code=1)
460459

461460
try:
462-
client = AgentEngineClient(project=project, location=location)
463-
memories = client.list_memories(agent_id)
461+
client = get_client(project=project, location=location)
462+
memories = list(client.list_memories(agent_id))
464463

465464
table = Table(title="Memories")
466465
table.add_column("Memory ID", style="cyan")
@@ -482,7 +481,7 @@ def list_memories(
482481

483482
# Format scope dict as key=value pairs
484483
scope_raw = getattr(memory, "scope", None)
485-
if scope_raw and isinstance(scope_raw, dict):
484+
if scope_raw and isinstance(scope_raw, (dict, MutableMapping)):
486485
scope = ", ".join(f"{k}={v}" for k, v in scope_raw.items())
487486
else:
488487
scope = ""

tests/__init__.py

Whitespace-only changes.

tests/fakes.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""Fake implementations for testing."""
2+
3+
from datetime import datetime
4+
from typing import Any, Iterator
5+
from dataclasses import dataclass
6+
from enum import Enum
7+
8+
from google.cloud.aiplatform_v1beta1.types import ReasoningEngine, ReasoningEngineSpec, Session, Memory
9+
10+
11+
class SandboxState(str, Enum):
12+
STATE_UNSPECIFIED = "STATE_UNSPECIFIED"
13+
STATE_RUNNING = "STATE_RUNNING"
14+
STATE_STOPPED = "STATE_STOPPED"
15+
16+
17+
@dataclass
18+
class Sandbox:
19+
"""Fake Sandbox since the real one is not easily importable."""
20+
name: str
21+
display_name: str
22+
state: Any
23+
create_time: datetime
24+
expire_time: datetime
25+
26+
27+
@dataclass
28+
class CreateAgentCall:
29+
"""Records the arguments passed to create_agent."""
30+
display_name: str
31+
identity_type: str
32+
service_account: str | None
33+
34+
35+
class FakeAgentEngineClient:
36+
"""Fake client for Agent Engine."""
37+
38+
def __init__(self, project: str, location: str):
39+
self.project = project
40+
self.location = location
41+
self._agents: dict[str, ReasoningEngine] = {}
42+
self._sessions: dict[str, list[Session]] = {}
43+
self._sandboxes: dict[str, list[Sandbox]] = {}
44+
self._memories: dict[str, list[Memory]] = {}
45+
self.create_agent_calls: list[CreateAgentCall] = []
46+
47+
def _get_full_name(self, resource_type: str, resource_id: str) -> str:
48+
return f"projects/{self.project}/locations/{self.location}/{resource_type}/{resource_id}"
49+
50+
def list_agents(self) -> Iterator[ReasoningEngine]:
51+
return iter(self._agents.values())
52+
53+
def get_agent(self, agent_id: str) -> ReasoningEngine:
54+
if "/" in agent_id:
55+
name = agent_id
56+
else:
57+
name = self._get_full_name("reasoningEngines", agent_id)
58+
59+
if name in self._agents:
60+
return self._agents[name]
61+
62+
for agent_name, agent in self._agents.items():
63+
if agent_name.endswith(f"/{agent_id}"):
64+
return agent
65+
66+
raise Exception(f"Agent {agent_id} not found")
67+
68+
def create_agent(
69+
self,
70+
display_name: str,
71+
identity_type: str,
72+
service_account: str | None = None,
73+
) -> ReasoningEngine:
74+
self.create_agent_calls.append(
75+
CreateAgentCall(display_name=display_name, identity_type=identity_type, service_account=service_account)
76+
)
77+
78+
agent_id = f"agent-{len(self._agents) + 1}"
79+
name = self._get_full_name("reasoningEngines", agent_id)
80+
81+
spec = ReasoningEngineSpec(agent_framework="langchain")
82+
agent = ReasoningEngine(
83+
name=name,
84+
display_name=display_name,
85+
spec=spec,
86+
create_time=datetime.now(),
87+
update_time=datetime.now(),
88+
)
89+
90+
self._agents[name] = agent
91+
return agent
92+
93+
def delete_agent(self, agent_id: str, force: bool = False) -> None:
94+
if "/" in agent_id:
95+
name = agent_id
96+
else:
97+
name = self._get_full_name("reasoningEngines", agent_id)
98+
99+
if name not in self._agents:
100+
for agent_name in list(self._agents.keys()):
101+
if agent_name.endswith(f"/{agent_id}"):
102+
name = agent_name
103+
break
104+
105+
if name in self._agents:
106+
if not force:
107+
if self._sessions.get(name) or self._memories.get(name) or self._sandboxes.get(name):
108+
raise Exception("Agent has resources, use force to delete")
109+
110+
del self._agents[name]
111+
self._sessions.pop(name, None)
112+
self._sandboxes.pop(name, None)
113+
self._memories.pop(name, None)
114+
else:
115+
raise Exception(f"Agent {agent_id} not found")
116+
117+
def list_sessions(self, agent_id: str) -> Iterator[Session]:
118+
agent = self.get_agent(agent_id)
119+
return iter(self._sessions.get(agent.name, []))
120+
121+
def list_sandboxes(self, agent_id: str) -> Iterator[Sandbox]:
122+
agent = self.get_agent(agent_id)
123+
return iter(self._sandboxes.get(agent.name, []))
124+
125+
def list_memories(self, agent_id: str) -> Iterator[Memory]:
126+
agent = self.get_agent(agent_id)
127+
return iter(self._memories.get(agent.name, []))

tests/test_fix_attribute_error.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
runner = CliRunner()
66

7-
@patch("agent_engine_cli.main.AgentEngineClient")
8-
def test_get_agent_with_none_class_methods(mock_client_class):
7+
@patch("agent_engine_cli.main.get_client")
8+
def test_get_agent_with_none_class_methods(mock_get_client):
99
"""Test get command when spec.class_methods is None (regression test)."""
1010

1111
# Mock spec with class_methods=None
@@ -25,7 +25,7 @@ def test_get_agent_with_none_class_methods(mock_client_class):
2525

2626
mock_client = MagicMock()
2727
mock_client.get_agent.return_value = mock_agent
28-
mock_client_class.return_value = mock_client
28+
mock_get_client.return_value = mock_client
2929

3030
result = runner.invoke(
3131
app, ["get", "agent1", "--project", "test-project", "--location", "us-central1"]

tests/test_get_effective_identity.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
runner = CliRunner()
66

7-
@patch("agent_engine_cli.main.AgentEngineClient")
8-
def test_get_agent_shows_effective_identity(mock_client_class):
7+
@patch("agent_engine_cli.main.get_client")
8+
def test_get_agent_shows_effective_identity(mock_get_client):
99
"""Test get command shows effective identity."""
1010
mock_spec = MagicMock()
1111
mock_spec.effective_identity = "service-account@test.iam.gserviceaccount.com"
@@ -27,7 +27,7 @@ def test_get_agent_shows_effective_identity(mock_client_class):
2727

2828
mock_client = MagicMock()
2929
mock_client.get_agent.return_value = mock_agent
30-
mock_client_class.return_value = mock_client
30+
mock_get_client.return_value = mock_client
3131

3232
result = runner.invoke(
3333
app, ["get", "agent1", "--project", "test-project", "--location", "us-central1"]

0 commit comments

Comments
 (0)