Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/agent_engine_cli/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ async def run_chat(
agent_id: str,
user_id: str = "cli-user",
debug: bool = False,
base_url: str | None = None,
api_version: str | None = None,
) -> None:
"""
Run an interactive chat session with an Agent Engine instance.
Expand All @@ -129,6 +131,8 @@ async def run_chat(
agent_id: Agent ID or full resource name.
user_id: User ID for the chat session.
debug: Enable verbose HTTP debug logging.
base_url: Optional override for the Vertex AI base URL.
api_version: Optional API version override.
"""
# Suppress vertexai experimental warnings
try:
Expand All @@ -149,7 +153,12 @@ async def run_chat(
import vertexai

# Get agent instance
client = vertexai.Client(project=project, location=location)
http_options: dict[str, str] = {}
if api_version:
http_options["api_version"] = api_version
if base_url:
http_options["base_url"] = base_url
client = vertexai.Client(project=project, location=location, http_options=http_options or None)
resource_name = (
f"projects/{project}/locations/{location}/reasoningEngines/{agent_id}"
)
Expand Down
12 changes: 10 additions & 2 deletions src/agent_engine_cli/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ class AgentResource(Protocol):
class AgentEngineClient:
"""Client for interacting with Vertex AI Agent Engine."""

def __init__(self, project: str, location: str):
def __init__(self, project: str, location: str, *, base_url: str | None = None, api_version: str | None = None):
"""Initialize the client with project and location.

Args:
project: Google Cloud project ID
location: Google Cloud region
base_url: Optional override for the Vertex AI base URL
api_version: Optional API version override
"""
self.project = project
self.location = location
Expand All @@ -38,10 +40,16 @@ def __init__(self, project: str, location: str):

vertexai.init(project=project, location=location)

http_options: dict[str, str] = {}
if api_version:
http_options["api_version"] = api_version
if base_url:
http_options["base_url"] = base_url

self._client = vertexai.Client(
project=project,
location=location,
http_options={"api_version": "v1beta1"},
http_options=http_options or None,
)

def _resolve_resource_name(self, agent_id: str) -> str:
Expand Down
12 changes: 10 additions & 2 deletions src/agent_engine_cli/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
from agent_engine_cli.client import AgentEngineClient


def get_client(project: str, location: str) -> AgentEngineClient:
def get_client(
project: str,
location: str,
*,
base_url: str | None = None,
api_version: str | None = None,
) -> AgentEngineClient:
"""Create a new AgentEngineClient instance.

This function serves as a dependency injection point for the CLI.
Expand All @@ -12,8 +18,10 @@ def get_client(project: str, location: str) -> AgentEngineClient:
Args:
project: Google Cloud project ID
location: Google Cloud region
base_url: Optional override for the Vertex AI base URL
api_version: Optional API version override

Returns:
An instance of AgentEngineClient
"""
return AgentEngineClient(project=project, location=location)
return AgentEngineClient(project=project, location=location, base_url=base_url, api_version=api_version)
32 changes: 25 additions & 7 deletions src/agent_engine_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def version():
def list_agents(
location: Annotated[str, typer.Option("--location", "-l", help="Google Cloud region")],
project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None,
base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None,
api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None,
) -> None:
"""List all agents in the project."""
try:
Expand All @@ -42,7 +44,7 @@ def list_agents(
raise typer.Exit(code=1)

try:
client = get_client(project=project, location=location)
client = get_client(project=project, location=location, base_url=base_url, api_version=api_version)
agents = list(client.list_agents())

if not agents:
Expand Down Expand Up @@ -99,6 +101,8 @@ def get_agent(
location: Annotated[str, typer.Option("--location", "-l", help="Google Cloud region")],
project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None,
full: Annotated[bool, typer.Option("--full", "-f", help="Show full JSON output")] = False,
base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None,
api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None,
) -> None:
"""Get details for a specific agent."""
try:
Expand All @@ -108,7 +112,7 @@ def get_agent(
raise typer.Exit(code=1)

try:
client = get_client(project=project, location=location)
client = get_client(project=project, location=location, base_url=base_url, api_version=api_version)
agent = client.get_agent(agent_id)

# v1beta1 api_resource uses 'name' instead of 'resource_name'
Expand Down Expand Up @@ -239,6 +243,8 @@ def create_agent(
str | None,
typer.Option("--service-account", "-s", help="Service account email (only used with --identity service_account)"),
] = None,
base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None,
api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None,
) -> None:
"""Create a new agent (without deploying code)."""
try:
Expand All @@ -248,7 +254,7 @@ def create_agent(
raise typer.Exit(code=1)

try:
client = get_client(project=project, location=location)
client = get_client(project=project, location=location, base_url=base_url, api_version=api_version)
console.print(f"Creating agent '{escape(display_name)}'...")

agent = client.create_agent(
Expand All @@ -274,6 +280,8 @@ def delete_agent(
project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None,
force: Annotated[bool, typer.Option("--force", "-f", help="Force deletion of agents with sessions/memory")] = False,
yes: Annotated[bool, typer.Option("--yes", "-y", help="Skip confirmation prompt")] = False,
base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None,
api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None,
) -> None:
"""Delete an agent."""
try:
Expand All @@ -289,7 +297,7 @@ def delete_agent(
raise typer.Exit()

try:
client = get_client(project=project, location=location)
client = get_client(project=project, location=location, base_url=base_url, api_version=api_version)
client.delete_agent(agent_id, force=force)
console.print(f"[red]Agent '{escape(agent_id)}' deleted.[/red]")
except Exception as e:
Expand All @@ -307,6 +315,8 @@ def list_sessions(
agent_id: Annotated[str, typer.Argument(help="Agent ID or full resource name")],
location: Annotated[str, typer.Option("--location", "-l", help="Google Cloud region")],
project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None,
base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None,
api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None,
) -> None:
"""List all sessions for an agent."""
try:
Expand All @@ -316,7 +326,7 @@ def list_sessions(
raise typer.Exit(code=1)

try:
client = get_client(project=project, location=location)
client = get_client(project=project, location=location, base_url=base_url, api_version=api_version)
sessions = list(client.list_sessions(agent_id))

if not sessions:
Expand Down Expand Up @@ -375,6 +385,8 @@ def list_sandboxes(
agent_id: Annotated[str, typer.Argument(help="Agent ID or full resource name")],
location: Annotated[str, typer.Option("--location", "-l", help="Google Cloud region")],
project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None,
base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None,
api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None,
) -> None:
"""List all sandboxes for an agent."""
try:
Expand All @@ -384,7 +396,7 @@ def list_sandboxes(
raise typer.Exit(code=1)

try:
client = get_client(project=project, location=location)
client = get_client(project=project, location=location, base_url=base_url, api_version=api_version)
sandboxes = list(client.list_sandboxes(agent_id))

if not sandboxes:
Expand Down Expand Up @@ -449,6 +461,8 @@ def list_memories(
agent_id: Annotated[str, typer.Argument(help="Agent ID or full resource name")],
location: Annotated[str, typer.Option("--location", "-l", help="Google Cloud region")],
project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None,
base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None,
api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None,
) -> None:
"""List all memories for an agent."""
try:
Expand All @@ -458,7 +472,7 @@ def list_memories(
raise typer.Exit(code=1)

try:
client = get_client(project=project, location=location)
client = get_client(project=project, location=location, base_url=base_url, api_version=api_version)
memories = list(client.list_memories(agent_id))

table = Table(title="Memories")
Expand Down Expand Up @@ -525,6 +539,8 @@ def chat(
project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None,
user: Annotated[str, typer.Option("--user", "-u", help="User ID for the chat session")] = "cli-user",
debug: Annotated[bool, typer.Option("--debug", "-d", help="Enable verbose HTTP debug logging")] = False,
base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None,
api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None,
) -> None:
"""Start an interactive chat session with an agent."""
try:
Expand All @@ -540,6 +556,8 @@ def chat(
agent_id=agent_id,
user_id=user,
debug=debug,
base_url=base_url,
api_version=api_version,
))
except KeyboardInterrupt:
console.print("\n[yellow]Chat session ended.[/yellow]")
Expand Down
54 changes: 54 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,60 @@ def test_init_custom_location(self, mock_vertexai, mock_agent_engines, mock_type

assert client.location == "europe-west1"

def test_init_default_http_options(self, mock_vertexai, mock_agent_engines, mock_types):
"""Test that default init passes no http_options."""
AgentEngineClient(project="test-project", location="us-central1")

mock_vertexai.Client.assert_called_once_with(
project="test-project",
location="us-central1",
http_options=None,
)

def test_init_custom_api_version(self, mock_vertexai, mock_agent_engines, mock_types):
"""Test that custom api_version is passed through in http_options."""
AgentEngineClient(project="test-project", location="us-central1", api_version="v1beta1")

mock_vertexai.Client.assert_called_once_with(
project="test-project",
location="us-central1",
http_options={"api_version": "v1beta1"},
)

def test_init_custom_base_url(self, mock_vertexai, mock_agent_engines, mock_types):
"""Test that custom base_url is passed through in http_options."""
AgentEngineClient(
project="test-project",
location="us-central1",
base_url="https://custom-endpoint.example.com",
)

mock_vertexai.Client.assert_called_once_with(
project="test-project",
location="us-central1",
http_options={
"base_url": "https://custom-endpoint.example.com",
},
)

def test_init_custom_base_url_and_api_version(self, mock_vertexai, mock_agent_engines, mock_types):
"""Test that both base_url and api_version are passed through together."""
AgentEngineClient(
project="test-project",
location="us-central1",
base_url="https://staging.example.com",
api_version="v1",
)

mock_vertexai.Client.assert_called_once_with(
project="test-project",
location="us-central1",
http_options={
"api_version": "v1",
"base_url": "https://staging.example.com",
},
)

def test_list_agents(self, mock_vertexai, mock_agent_engines, mock_types):
"""Test listing agents."""
mock_api_resource1 = MagicMock()
Expand Down
Loading