diff --git a/src/agent_engine_cli/chat.py b/src/agent_engine_cli/chat.py index 2824a17..097799f 100644 --- a/src/agent_engine_cli/chat.py +++ b/src/agent_engine_cli/chat.py @@ -119,6 +119,8 @@ async def run_chat( agent_id: str, user_id: str = "cli-user", debug: bool = False, + base_url: str | None = None, + api_version: str | None = None, ) -> None: """ Run an interactive chat session with an Agent Engine instance. @@ -129,6 +131,8 @@ async def run_chat( agent_id: Agent ID or full resource name. user_id: User ID for the chat session. debug: Enable verbose HTTP debug logging. + base_url: Optional override for the Vertex AI base URL. + api_version: Optional API version override. """ # Suppress vertexai experimental warnings try: @@ -149,7 +153,12 @@ async def run_chat( import vertexai # Get agent instance - client = vertexai.Client(project=project, location=location) + http_options: dict[str, str] = {} + if api_version: + http_options["api_version"] = api_version + if base_url: + http_options["base_url"] = base_url + client = vertexai.Client(project=project, location=location, http_options=http_options or None) resource_name = ( f"projects/{project}/locations/{location}/reasoningEngines/{agent_id}" ) diff --git a/src/agent_engine_cli/client.py b/src/agent_engine_cli/client.py index abe7187..ef07479 100644 --- a/src/agent_engine_cli/client.py +++ b/src/agent_engine_cli/client.py @@ -24,12 +24,14 @@ class AgentResource(Protocol): class AgentEngineClient: """Client for interacting with Vertex AI Agent Engine.""" - def __init__(self, project: str, location: str): + def __init__(self, project: str, location: str, *, base_url: str | None = None, api_version: str | None = None): """Initialize the client with project and location. Args: project: Google Cloud project ID location: Google Cloud region + base_url: Optional override for the Vertex AI base URL + api_version: Optional API version override """ self.project = project self.location = location @@ -38,10 +40,16 @@ def __init__(self, project: str, location: str): vertexai.init(project=project, location=location) + http_options: dict[str, str] = {} + if api_version: + http_options["api_version"] = api_version + if base_url: + http_options["base_url"] = base_url + self._client = vertexai.Client( project=project, location=location, - http_options={"api_version": "v1beta1"}, + http_options=http_options or None, ) def _resolve_resource_name(self, agent_id: str) -> str: diff --git a/src/agent_engine_cli/dependencies.py b/src/agent_engine_cli/dependencies.py index f990b63..d0b706d 100644 --- a/src/agent_engine_cli/dependencies.py +++ b/src/agent_engine_cli/dependencies.py @@ -3,7 +3,13 @@ from agent_engine_cli.client import AgentEngineClient -def get_client(project: str, location: str) -> AgentEngineClient: +def get_client( + project: str, + location: str, + *, + base_url: str | None = None, + api_version: str | None = None, +) -> AgentEngineClient: """Create a new AgentEngineClient instance. This function serves as a dependency injection point for the CLI. @@ -12,8 +18,10 @@ def get_client(project: str, location: str) -> AgentEngineClient: Args: project: Google Cloud project ID location: Google Cloud region + base_url: Optional override for the Vertex AI base URL + api_version: Optional API version override Returns: An instance of AgentEngineClient """ - return AgentEngineClient(project=project, location=location) + return AgentEngineClient(project=project, location=location, base_url=base_url, api_version=api_version) diff --git a/src/agent_engine_cli/main.py b/src/agent_engine_cli/main.py index 3c40197..5e5bb8c 100644 --- a/src/agent_engine_cli/main.py +++ b/src/agent_engine_cli/main.py @@ -33,6 +33,8 @@ def version(): def list_agents( location: Annotated[str, typer.Option("--location", "-l", help="Google Cloud region")], project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None, + base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None, + api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None, ) -> None: """List all agents in the project.""" try: @@ -42,7 +44,7 @@ def list_agents( raise typer.Exit(code=1) try: - client = get_client(project=project, location=location) + client = get_client(project=project, location=location, base_url=base_url, api_version=api_version) agents = list(client.list_agents()) if not agents: @@ -99,6 +101,8 @@ def get_agent( location: Annotated[str, typer.Option("--location", "-l", help="Google Cloud region")], project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None, full: Annotated[bool, typer.Option("--full", "-f", help="Show full JSON output")] = False, + base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None, + api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None, ) -> None: """Get details for a specific agent.""" try: @@ -108,7 +112,7 @@ def get_agent( raise typer.Exit(code=1) try: - client = get_client(project=project, location=location) + client = get_client(project=project, location=location, base_url=base_url, api_version=api_version) agent = client.get_agent(agent_id) # v1beta1 api_resource uses 'name' instead of 'resource_name' @@ -239,6 +243,8 @@ def create_agent( str | None, typer.Option("--service-account", "-s", help="Service account email (only used with --identity service_account)"), ] = None, + base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None, + api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None, ) -> None: """Create a new agent (without deploying code).""" try: @@ -248,7 +254,7 @@ def create_agent( raise typer.Exit(code=1) try: - client = get_client(project=project, location=location) + client = get_client(project=project, location=location, base_url=base_url, api_version=api_version) console.print(f"Creating agent '{escape(display_name)}'...") agent = client.create_agent( @@ -274,6 +280,8 @@ def delete_agent( project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None, force: Annotated[bool, typer.Option("--force", "-f", help="Force deletion of agents with sessions/memory")] = False, yes: Annotated[bool, typer.Option("--yes", "-y", help="Skip confirmation prompt")] = False, + base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None, + api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None, ) -> None: """Delete an agent.""" try: @@ -289,7 +297,7 @@ def delete_agent( raise typer.Exit() try: - client = get_client(project=project, location=location) + client = get_client(project=project, location=location, base_url=base_url, api_version=api_version) client.delete_agent(agent_id, force=force) console.print(f"[red]Agent '{escape(agent_id)}' deleted.[/red]") except Exception as e: @@ -307,6 +315,8 @@ def list_sessions( agent_id: Annotated[str, typer.Argument(help="Agent ID or full resource name")], location: Annotated[str, typer.Option("--location", "-l", help="Google Cloud region")], project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None, + base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None, + api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None, ) -> None: """List all sessions for an agent.""" try: @@ -316,7 +326,7 @@ def list_sessions( raise typer.Exit(code=1) try: - client = get_client(project=project, location=location) + client = get_client(project=project, location=location, base_url=base_url, api_version=api_version) sessions = list(client.list_sessions(agent_id)) if not sessions: @@ -375,6 +385,8 @@ def list_sandboxes( agent_id: Annotated[str, typer.Argument(help="Agent ID or full resource name")], location: Annotated[str, typer.Option("--location", "-l", help="Google Cloud region")], project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None, + base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None, + api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None, ) -> None: """List all sandboxes for an agent.""" try: @@ -384,7 +396,7 @@ def list_sandboxes( raise typer.Exit(code=1) try: - client = get_client(project=project, location=location) + client = get_client(project=project, location=location, base_url=base_url, api_version=api_version) sandboxes = list(client.list_sandboxes(agent_id)) if not sandboxes: @@ -449,6 +461,8 @@ def list_memories( agent_id: Annotated[str, typer.Argument(help="Agent ID or full resource name")], location: Annotated[str, typer.Option("--location", "-l", help="Google Cloud region")], project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None, + base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None, + api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None, ) -> None: """List all memories for an agent.""" try: @@ -458,7 +472,7 @@ def list_memories( raise typer.Exit(code=1) try: - client = get_client(project=project, location=location) + client = get_client(project=project, location=location, base_url=base_url, api_version=api_version) memories = list(client.list_memories(agent_id)) table = Table(title="Memories") @@ -525,6 +539,8 @@ def chat( project: Annotated[str | None, typer.Option("--project", "-p", help="Google Cloud project ID (defaults to ADC project)")] = None, user: Annotated[str, typer.Option("--user", "-u", help="User ID for the chat session")] = "cli-user", debug: Annotated[bool, typer.Option("--debug", "-d", help="Enable verbose HTTP debug logging")] = False, + base_url: Annotated[str | None, typer.Option("--base-url", help="Override the Vertex AI base URL")] = None, + api_version: Annotated[str | None, typer.Option("--api-version", help="Override the API version")] = None, ) -> None: """Start an interactive chat session with an agent.""" try: @@ -540,6 +556,8 @@ def chat( agent_id=agent_id, user_id=user, debug=debug, + base_url=base_url, + api_version=api_version, )) except KeyboardInterrupt: console.print("\n[yellow]Chat session ended.[/yellow]") diff --git a/tests/test_client.py b/tests/test_client.py index 597b5b8..969b971 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -50,6 +50,60 @@ def test_init_custom_location(self, mock_vertexai, mock_agent_engines, mock_type assert client.location == "europe-west1" + def test_init_default_http_options(self, mock_vertexai, mock_agent_engines, mock_types): + """Test that default init passes no http_options.""" + AgentEngineClient(project="test-project", location="us-central1") + + mock_vertexai.Client.assert_called_once_with( + project="test-project", + location="us-central1", + http_options=None, + ) + + def test_init_custom_api_version(self, mock_vertexai, mock_agent_engines, mock_types): + """Test that custom api_version is passed through in http_options.""" + AgentEngineClient(project="test-project", location="us-central1", api_version="v1beta1") + + mock_vertexai.Client.assert_called_once_with( + project="test-project", + location="us-central1", + http_options={"api_version": "v1beta1"}, + ) + + def test_init_custom_base_url(self, mock_vertexai, mock_agent_engines, mock_types): + """Test that custom base_url is passed through in http_options.""" + AgentEngineClient( + project="test-project", + location="us-central1", + base_url="https://custom-endpoint.example.com", + ) + + mock_vertexai.Client.assert_called_once_with( + project="test-project", + location="us-central1", + http_options={ + "base_url": "https://custom-endpoint.example.com", + }, + ) + + def test_init_custom_base_url_and_api_version(self, mock_vertexai, mock_agent_engines, mock_types): + """Test that both base_url and api_version are passed through together.""" + AgentEngineClient( + project="test-project", + location="us-central1", + base_url="https://staging.example.com", + api_version="v1", + ) + + mock_vertexai.Client.assert_called_once_with( + project="test-project", + location="us-central1", + http_options={ + "api_version": "v1", + "base_url": "https://staging.example.com", + }, + ) + def test_list_agents(self, mock_vertexai, mock_agent_engines, mock_types): """Test listing agents.""" mock_api_resource1 = MagicMock() diff --git a/tests/test_main.py b/tests/test_main.py index c1bf0a8..8ca5398 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -3,6 +3,7 @@ from datetime import datetime from unittest.mock import AsyncMock, MagicMock, patch +import pytest from typer.testing import CliRunner from google.cloud.aiplatform_v1beta1.types import ReasoningEngine, ReasoningEngineSpec, Session, Memory @@ -283,6 +284,8 @@ def test_chat_invokes_run_chat(self, mock_run_chat): agent_id="agent123", user_id="cli-user", debug=False, + base_url=None, + api_version=None, ) @patch("agent_engine_cli.main.run_chat") @@ -302,6 +305,8 @@ def test_chat_with_user_and_debug(self, mock_run_chat): agent_id="agent123", user_id="my-user", debug=True, + base_url=None, + api_version=None, ) @patch("agent_engine_cli.main.run_chat") @@ -331,7 +336,7 @@ def test_list_uses_adc_project(self, mock_resolve, mock_get_client): result = runner.invoke(app, ["list", "--location", "us-central1"]) assert result.exit_code == 0 mock_resolve.assert_called_once_with(None) - mock_get_client.assert_called_once_with(project="adc-project", location="us-central1") + mock_get_client.assert_called_once_with(project="adc-project", location="us-central1", base_url=None, api_version=None) @patch("agent_engine_cli.main.resolve_project") def test_list_error_when_no_project(self, mock_resolve): @@ -400,6 +405,8 @@ def test_chat_uses_adc_project(self, mock_resolve, mock_run_chat): agent_id="agent1", user_id="cli-user", debug=False, + base_url=None, + api_version=None, ) @patch("agent_engine_cli.main.get_client") @@ -489,7 +496,7 @@ def test_sessions_list_uses_adc_project(self, mock_resolve, mock_get_client): result = runner.invoke(app, ["sessions", "list", "agent1", "--location", "us-central1"]) assert result.exit_code == 0 mock_resolve.assert_called_once_with(None) - mock_get_client.assert_called_once_with(project="adc-project", location="us-central1") + mock_get_client.assert_called_once_with(project="adc-project", location="us-central1", base_url=None, api_version=None) class TestSandboxesListCommand: @@ -568,7 +575,7 @@ def test_sandboxes_list_uses_adc_project(self, mock_resolve, mock_get_client): result = runner.invoke(app, ["sandboxes", "list", "agent1", "--location", "us-central1"]) assert result.exit_code == 0 mock_resolve.assert_called_once_with(None) - mock_get_client.assert_called_once_with(project="adc-project", location="us-central1") + mock_get_client.assert_called_once_with(project="adc-project", location="us-central1", base_url=None, api_version=None) class TestMemoriesListCommand: @@ -648,4 +655,65 @@ def test_memories_list_uses_adc_project(self, mock_resolve, mock_get_client): result = runner.invoke(app, ["memories", "list", "agent1", "--location", "us-central1"]) assert result.exit_code == 0 mock_resolve.assert_called_once_with(None) - mock_get_client.assert_called_once_with(project="adc-project", location="us-central1") + mock_get_client.assert_called_once_with(project="adc-project", location="us-central1", base_url=None, api_version=None) + + +class TestEndpointOverrideOptions: + """Tests for --base-url and --api-version options.""" + + @pytest.mark.parametrize("command,args", [ + (["list"], []), + (["get"], ["agent1"]), + (["create"], ["My Agent"]), + (["delete"], ["agent1"]), + (["chat"], ["agent1"]), + (["sessions", "list"], ["agent1"]), + (["sandboxes", "list"], ["agent1"]), + (["memories", "list"], ["agent1"]), + ]) + def test_help_shows_base_url_and_api_version(self, command, args): + """Test that --base-url and --api-version appear in help for all commands.""" + result = runner.invoke(app, command + args[:0] + ["--help"]) + assert result.exit_code == 0 + assert "--base-url" in result.stdout + assert "--api-version" in result.stdout + + @patch("agent_engine_cli.main.get_client") + def test_list_with_custom_endpoint_options(self, mock_get_client): + """Test list command passes custom base_url and api_version.""" + fake_client = FakeAgentEngineClient(project="test-project", location="us-central1") + mock_get_client.return_value = fake_client + + result = runner.invoke(app, [ + "list", "--project", "test-project", "--location", "us-central1", + "--base-url", "https://custom.example.com", + "--api-version", "v1", + ]) + assert result.exit_code == 0 + mock_get_client.assert_called_once_with( + project="test-project", + location="us-central1", + base_url="https://custom.example.com", + api_version="v1", + ) + + @patch("agent_engine_cli.main.run_chat") + def test_chat_with_custom_endpoint_options(self, mock_run_chat): + """Test chat command passes custom base_url and api_version to run_chat.""" + mock_run_chat.return_value = AsyncMock()() + + result = runner.invoke(app, [ + "chat", "agent123", "--project", "test-project", "--location", "us-central1", + "--base-url", "https://staging.example.com", + "--api-version", "v1", + ]) + assert result.exit_code == 0 + mock_run_chat.assert_called_once_with( + project="test-project", + location="us-central1", + agent_id="agent123", + user_id="cli-user", + debug=False, + base_url="https://staging.example.com", + api_version="v1", + )