From 6aae761f64b6df559d1e8c02012aba25a171b624 Mon Sep 17 00:00:00 2001 From: mmontan Date: Wed, 11 Feb 2026 19:44:18 -0800 Subject: [PATCH] feat: Add --base-url and --api-version options for Vertex AI endpoint overrides Allow users to override the Vertex AI base URL and API version on all commands, useful for testing against staging environments, private endpoints, or different API versions. Co-Authored-By: Claude Opus 4.6 --- src/agent_engine_cli/chat.py | 11 +++- src/agent_engine_cli/client.py | 12 ++++- src/agent_engine_cli/dependencies.py | 12 ++++- src/agent_engine_cli/main.py | 32 +++++++++--- tests/test_client.py | 54 ++++++++++++++++++++ tests/test_main.py | 76 ++++++++++++++++++++++++++-- 6 files changed, 181 insertions(+), 16 deletions(-) diff --git a/src/agent_engine_cli/chat.py b/src/agent_engine_cli/chat.py index 2824a17..097799f 100644 --- a/src/agent_engine_cli/chat.py +++ b/src/agent_engine_cli/chat.py @@ -119,6 +119,8 @@ async def run_chat( agent_id: str, user_id: str = "cli-user", debug: bool = False, + base_url: str | None = None, + api_version: str | None = None, ) -> None: """ Run an interactive chat session with an Agent Engine instance. @@ -129,6 +131,8 @@ async def run_chat( agent_id: Agent ID or full resource name. user_id: User ID for the chat session. debug: Enable verbose HTTP debug logging. + base_url: Optional override for the Vertex AI base URL. + api_version: Optional API version override. """ # Suppress vertexai experimental warnings try: @@ -149,7 +153,12 @@ async def run_chat( import vertexai # Get agent instance - client = vertexai.Client(project=project, location=location) + http_options: dict[str, str] = {} + if api_version: + http_options["api_version"] = api_version + if base_url: + http_options["base_url"] = base_url + client = vertexai.Client(project=project, location=location, http_options=http_options or None) resource_name = ( f"projects/{project}/locations/{location}/reasoningEngines/{agent_id}" ) diff --git a/src/agent_engine_cli/client.py b/src/agent_engine_cli/client.py index abe7187..ef07479 100644 --- a/src/agent_engine_cli/client.py +++ b/src/agent_engine_cli/client.py @@ -24,12 +24,14 @@ class AgentResource(Protocol): class AgentEngineClient: """Client for interacting with Vertex AI Agent Engine.""" - def __init__(self, project: str, location: str): + def __init__(self, project: str, location: str, *, base_url: str | None = None, api_version: str | None = None): """Initialize the client with project and location. Args: project: Google Cloud project ID location: Google Cloud region + base_url: Optional override for the Vertex AI base URL + api_version: Optional API version override """ self.project = project self.location = location @@ -38,10 +40,16 @@ def __init__(self, project: str, location: str): vertexai.init(project=project, location=location) + http_options: dict[str, str] = {} + if api_version: + http_options["api_version"] = api_version + if base_url: + http_options["base_url"] = base_url + self._client = vertexai.Client( project=project, location=location, - http_options={"api_version": "v1beta1"}, + http_options=http_options or None, ) def _resolve_resource_name(self, agent_id: str) -> str: diff --git a/src/agent_engine_cli/dependencies.py b/src/agent_engine_cli/dependencies.py index f990b63..d0b706d 100644 --- a/src/agent_engine_cli/dependencies.py +++ b/src/agent_engine_cli/dependencies.py @@ -3,7 +3,13 @@ from agent_engine_cli.client import AgentEngineClient -def get_client(project: str, location: str) -> AgentEngineClient: +def get_client( + project: str, + location: str, + *, + base_url: str | None = None, + api_version: str | None = None, +) -> AgentEngineClient: """Create a new AgentEngineClient instance. This function serves as a dependency injection point for the CLI. @@ -12,8 +18,10 @@ def get_client(project: str, location: str) -> AgentEngineClient: Args: project: Google Cloud project ID location: Google Cloud region + base_url: Optional override for the Vertex AI base URL + api_version: Optional API version override Returns: An instance of AgentEngineClient """ - return AgentEngineClient(project=project, location=location) + return AgentEngineClient(project=project, location=location, base_url=base_url, api_version=api_version) diff --git a/src/agent_engine_cli/main.py b/src/agent_engine_cli/main.py index 3c40197..5e5bb8c 100644 --- a/src/agent_engine_cli/main.py +++ b/src/agent_engine_cli/main.py @@ -33,6 +33,8 @@ def version(): def list_agents( location: Annotated[str, typer.Option("--location", "-l", help="Google Cloud region")], project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None, + base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None, + api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None, ) -> None: """List all agents in the project.""" try: @@ -42,7 +44,7 @@ def list_agents( raise typer.Exit(code=1) try: - client = get_client(project=project, location=location) + client = get_client(project=project, location=location, base_url=base_url, api_version=api_version) agents = list(client.list_agents()) if not agents: @@ -99,6 +101,8 @@ def get_agent( location: Annotated[str, typer.Option("--location", "-l", help="Google Cloud region")], project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None, full: Annotated[bool, typer.Option("--full", "-f", help="Show full JSON output")] = False, + base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None, + api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None, ) -> None: """Get details for a specific agent.""" try: @@ -108,7 +112,7 @@ def get_agent( raise typer.Exit(code=1) try: - client = get_client(project=project, location=location) + client = get_client(project=project, location=location, base_url=base_url, api_version=api_version) agent = client.get_agent(agent_id) # v1beta1 api_resource uses 'name' instead of 'resource_name' @@ -239,6 +243,8 @@ def create_agent( str | None, typer.Option("--service-account", "-s", help="Service account email (only used with --identity service_account)"), ] = None, + base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None, + api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None, ) -> None: """Create a new agent (without deploying code).""" try: @@ -248,7 +254,7 @@ def create_agent( raise typer.Exit(code=1) try: - client = get_client(project=project, location=location) + client = get_client(project=project, location=location, base_url=base_url, api_version=api_version) console.print(f"Creating agent '{escape(display_name)}'...") agent = client.create_agent( @@ -274,6 +280,8 @@ def delete_agent( project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None, force: Annotated[bool, typer.Option("--force", "-f", help="Force deletion of agents with sessions/memory")] = False, yes: Annotated[bool, typer.Option("--yes", "-y", help="Skip confirmation prompt")] = False, + base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None, + api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None, ) -> None: """Delete an agent.""" try: @@ -289,7 +297,7 @@ def delete_agent( raise typer.Exit() try: - client = get_client(project=project, location=location) + client = get_client(project=project, location=location, base_url=base_url, api_version=api_version) client.delete_agent(agent_id, force=force) console.print(f"[red]Agent '{escape(agent_id)}' deleted.[/red]") except Exception as e: @@ -307,6 +315,8 @@ def list_sessions( agent_id: Annotated[str, typer.Argument(help="Agent ID or full resource name")], location: Annotated[str, typer.Option("--location", "-l", help="Google Cloud region")], project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None, + base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None, + api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None, ) -> None: """List all sessions for an agent.""" try: @@ -316,7 +326,7 @@ def list_sessions( raise typer.Exit(code=1) try: - client = get_client(project=project, location=location) + client = get_client(project=project, location=location, base_url=base_url, api_version=api_version) sessions = list(client.list_sessions(agent_id)) if not sessions: @@ -375,6 +385,8 @@ def list_sandboxes( agent_id: Annotated[str, typer.Argument(help="Agent ID or full resource name")], location: Annotated[str, typer.Option("--location", "-l", help="Google Cloud region")], project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None, + base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None, + api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None, ) -> None: """List all sandboxes for an agent.""" try: @@ -384,7 +396,7 @@ def list_sandboxes( raise typer.Exit(code=1) try: - client = get_client(project=project, location=location) + client = get_client(project=project, location=location, base_url=base_url, api_version=api_version) sandboxes = list(client.list_sandboxes(agent_id)) if not sandboxes: @@ -449,6 +461,8 @@ def list_memories( agent_id: Annotated[str, typer.Argument(help="Agent ID or full resource name")], location: Annotated[str, typer.Option("--location", "-l", help="Google Cloud region")], project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None, + base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None, + api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None, ) -> None: """List all memories for an agent.""" try: @@ -458,7 +472,7 @@ def list_memories( raise typer.Exit(code=1) try: - client = get_client(project=project, location=location) + client = get_client(project=project, location=location, base_url=base_url, api_version=api_version) memories = list(client.list_memories(agent_id)) table = Table(title="Memories") @@ -525,6 +539,8 @@ def chat( project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None, user: Annotated[str, typer.Option("--user", "-u", help="User ID for the chat session")] = "cli-user", debug: Annotated[bool, typer.Option("--debug", "-d", help="Enable verbose HTTP debug logging")] = False, + base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None, + api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None, ) -> None: """Start an interactive chat session with an agent.""" try: @@ -540,6 +556,8 @@ def chat( agent_id=agent_id, user_id=user, debug=debug, + base_url=base_url, + api_version=api_version, )) except KeyboardInterrupt: console.print("\n[yellow]Chat session ended.[/yellow]") diff --git a/tests/test_client.py b/tests/test_client.py index 597b5b8..969b971 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -50,6 +50,60 @@ def test_init_custom_location(self, mock_vertexai, mock_agent_engines, mock_type assert client.location == "europe-west1" + def test_init_default_http_options(self, mock_vertexai, mock_agent_engines, mock_types): + """Test that default init passes no http_options.""" + AgentEngineClient(project="test-project", location="us-central1") + + mock_vertexai.Client.assert_called_once_with( + project="test-project", + location="us-central1", + http_options=None, + ) + + def test_init_custom_api_version(self, mock_vertexai, mock_agent_engines, mock_types): + """Test that custom api_version is passed through in http_options.""" + AgentEngineClient(project="test-project", location="us-central1", api_version="v1beta1") + + mock_vertexai.Client.assert_called_once_with( + project="test-project", + location="us-central1", + http_options={"api_version": "v1beta1"}, + ) + + def test_init_custom_base_url(self, mock_vertexai, mock_agent_engines, mock_types): + """Test that custom base_url is passed through in http_options.""" + AgentEngineClient( + project="test-project", + location="us-central1", + base_url="https://custom-endpoint.example.com", + ) + + mock_vertexai.Client.assert_called_once_with( + project="test-project", + location="us-central1", + http_options={ + "base_url": "https://custom-endpoint.example.com", + }, + ) + + def test_init_custom_base_url_and_api_version(self, mock_vertexai, mock_agent_engines, mock_types): + """Test that both base_url and api_version are passed through together.""" + AgentEngineClient( + project="test-project", + location="us-central1", + base_url="https://staging.example.com", + api_version="v1", + ) + + mock_vertexai.Client.assert_called_once_with( + project="test-project", + location="us-central1", + http_options={ + "api_version": "v1", + "base_url": "https://staging.example.com", + }, + ) + def test_list_agents(self, mock_vertexai, mock_agent_engines, mock_types): """Test listing agents.""" mock_api_resource1 = MagicMock() diff --git a/tests/test_main.py b/tests/test_main.py index c1bf0a8..8ca5398 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -3,6 +3,7 @@ from datetime import datetime from unittest.mock import AsyncMock, MagicMock, patch +import pytest from typer.testing import CliRunner from google.cloud.aiplatform_v1beta1.types import ReasoningEngine, ReasoningEngineSpec, Session, Memory @@ -283,6 +284,8 @@ def test_chat_invokes_run_chat(self, mock_run_chat): agent_id="agent123", user_id="cli-user", debug=False, + base_url=None, + api_version=None, ) @patch("agent_engine_cli.main.run_chat") @@ -302,6 +305,8 @@ def test_chat_with_user_and_debug(self, mock_run_chat): agent_id="agent123", user_id="my-user", debug=True, + base_url=None, + api_version=None, ) @patch("agent_engine_cli.main.run_chat") @@ -331,7 +336,7 @@ def test_list_uses_adc_project(self, mock_resolve, mock_get_client): result = runner.invoke(app, ["list", "--location", "us-central1"]) assert result.exit_code == 0 mock_resolve.assert_called_once_with(None) - mock_get_client.assert_called_once_with(project="adc-project", location="us-central1") + mock_get_client.assert_called_once_with(project="adc-project", location="us-central1", base_url=None, api_version=None) @patch("agent_engine_cli.main.resolve_project") def test_list_error_when_no_project(self, mock_resolve): @@ -400,6 +405,8 @@ def test_chat_uses_adc_project(self, mock_resolve, mock_run_chat): agent_id="agent1", user_id="cli-user", debug=False, + base_url=None, + api_version=None, ) @patch("agent_engine_cli.main.get_client") @@ -489,7 +496,7 @@ def test_sessions_list_uses_adc_project(self, mock_resolve, mock_get_client): result = runner.invoke(app, ["sessions", "list", "agent1", "--location", "us-central1"]) assert result.exit_code == 0 mock_resolve.assert_called_once_with(None) - mock_get_client.assert_called_once_with(project="adc-project", location="us-central1") + mock_get_client.assert_called_once_with(project="adc-project", location="us-central1", base_url=None, api_version=None) class TestSandboxesListCommand: @@ -568,7 +575,7 @@ def test_sandboxes_list_uses_adc_project(self, mock_resolve, mock_get_client): result = runner.invoke(app, ["sandboxes", "list", "agent1", "--location", "us-central1"]) assert result.exit_code == 0 mock_resolve.assert_called_once_with(None) - mock_get_client.assert_called_once_with(project="adc-project", location="us-central1") + mock_get_client.assert_called_once_with(project="adc-project", location="us-central1", base_url=None, api_version=None) class TestMemoriesListCommand: @@ -648,4 +655,65 @@ def test_memories_list_uses_adc_project(self, mock_resolve, mock_get_client): result = runner.invoke(app, ["memories", "list", "agent1", "--location", "us-central1"]) assert result.exit_code == 0 mock_resolve.assert_called_once_with(None) - mock_get_client.assert_called_once_with(project="adc-project", location="us-central1") + mock_get_client.assert_called_once_with(project="adc-project", location="us-central1", base_url=None, api_version=None) + + +class TestEndpointOverrideOptions: + """Tests for --base-url and --api-version options.""" + + @pytest.mark.parametrize("command,args", [ + (["list"], []), + (["get"], ["agent1"]), + (["create"], ["My Agent"]), + (["delete"], ["agent1"]), + (["chat"], ["agent1"]), + (["sessions", "list"], ["agent1"]), + (["sandboxes", "list"], ["agent1"]), + (["memories", "list"], ["agent1"]), + ]) + def test_help_shows_base_url_and_api_version(self, command, args): + """Test that --base-url and --api-version appear in help for all commands.""" + result = runner.invoke(app, command + args[:0] + ["--help"]) + assert result.exit_code == 0 + assert "--base-url" in result.stdout + assert "--api-version" in result.stdout + + @patch("agent_engine_cli.main.get_client") + def test_list_with_custom_endpoint_options(self, mock_get_client): + """Test list command passes custom base_url and api_version.""" + fake_client = FakeAgentEngineClient(project="test-project", location="us-central1") + mock_get_client.return_value = fake_client + + result = runner.invoke(app, [ + "list", "--project", "test-project", "--location", "us-central1", + "--base-url", "https://custom.example.com", + "--api-version", "v1", + ]) + assert result.exit_code == 0 + mock_get_client.assert_called_once_with( + project="test-project", + location="us-central1", + base_url="https://custom.example.com", + api_version="v1", + ) + + @patch("agent_engine_cli.main.run_chat") + def test_chat_with_custom_endpoint_options(self, mock_run_chat): + """Test chat command passes custom base_url and api_version to run_chat.""" + mock_run_chat.return_value = AsyncMock()() + + result = runner.invoke(app, [ + "chat", "agent123", "--project", "test-project", "--location", "us-central1", + "--base-url", "https://staging.example.com", + "--api-version", "v1", + ]) + assert result.exit_code == 0 + mock_run_chat.assert_called_once_with( + project="test-project", + location="us-central1", + agent_id="agent123", + user_id="cli-user", + debug=False, + base_url="https://staging.example.com", + api_version="v1", + )