diff --git a/tests/cli/test_cli_predictions.py b/tests/cli/test_cli_predictions.py index 06aafaf..1e9c3a8 100644 --- a/tests/cli/test_cli_predictions.py +++ b/tests/cli/test_cli_predictions.py @@ -3,8 +3,9 @@ from vlmrun.cli.cli import app -def test_list_predictions(runner, mock_client, config_file): +def test_list_predictions(runner, mock_client, config_file, monkeypatch): """Test list predictions command.""" + monkeypatch.setenv("COLUMNS", "200") result = runner.invoke(app, ["predictions", "list"]) assert result.exit_code == 0 assert "prediction1" in result.stdout @@ -57,8 +58,9 @@ def test_get_prediction_usage_display(runner, mock_client, config_file): assert "100" in result.stdout -def test_list_predictions_table_format(runner, mock_client, config_file): +def test_list_predictions_table_format(runner, mock_client, config_file, monkeypatch): """Test that list output is formatted correctly.""" + monkeypatch.setenv("COLUMNS", "200") result = runner.invoke(app, ["predictions", "list"]) assert result.exit_code == 0 assert "id" in result.stdout diff --git a/tests/conftest.py b/tests/conftest.py index b1287ed..fd82d72 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -87,9 +87,10 @@ def execute(self, name: str, **kwargs): ) return prediction - def __init__(self, api_key=None, base_url=None): + def __init__(self, api_key=None, base_url=None, agent_base_url=None): self.api_key = api_key or "test-key" self.base_url = base_url or "https://api.vlm.run" + self.agent_base_url = agent_base_url or self.base_url self.timeout = 120.0 self.max_retries = 1 self.dataset = self.Dataset(self) diff --git a/vlmrun/cli/_cli/chat.py b/vlmrun/cli/_cli/chat.py index 1741b1e..ab594ba 100644 --- a/vlmrun/cli/_cli/chat.py +++ b/vlmrun/cli/_cli/chat.py @@ -488,9 +488,9 @@ def chat( help="Artifact output directory. [default: ~/.vlmrun/cache/artifacts/]", ), base_url: Optional[str] = typer.Option( - os.getenv("VLMRUN_BASE_URL", DEFAULT_BASE_URL), + os.getenv("VLMRUN_AGENT_BASE_URL", os.getenv("VLMRUN_BASE_URL", DEFAULT_BASE_URL)), "--base-url", - help="VLM Run Agent API base URL.", + help="VLM Run Agent API base URL. Falls back to VLMRUN_AGENT_BASE_URL, then VLMRUN_BASE_URL.", ), model: str = typer.Option( DEFAULT_MODEL, diff --git a/vlmrun/client/agent.py b/vlmrun/client/agent.py index 7f859f4..516a50d 100644 --- a/vlmrun/client/agent.py +++ b/vlmrun/client/agent.py @@ -280,7 +280,7 @@ def completions(self): error_type="missing_dependency", ) - base_url = f"{self._client.base_url}/openai" + base_url = f"{self._client.agent_base_url}/openai" openai_client = OpenAI( api_key=self._client.api_key, base_url=base_url, @@ -332,7 +332,7 @@ async def main(): error_type="missing_dependency", ) - base_url = f"{self._client.base_url}/openai" + base_url = f"{self._client.agent_base_url}/openai" async_openai_client = AsyncOpenAI( api_key=self._client.api_key, base_url=base_url, diff --git a/vlmrun/client/client.py b/vlmrun/client/client.py index 25c80e5..a00ec3f 100644 --- a/vlmrun/client/client.py +++ b/vlmrun/client/client.py @@ -41,6 +41,8 @@ class VLMRun: or VLMRUN_API_KEY environment variable. base_url: Base URL for API. Defaults to None, which falls back to VLMRUN_BASE_URL environment variable or https://api.vlm.run/v1. + agent_base_url: Base URL for the agent (OpenAI-compatible completions) endpoint. + Falls back to VLMRUN_AGENT_BASE_URL environment variable, then to base_url. timeout: Request timeout in seconds. Defaults to 120.0. max_retries: Maximum number of retry attempts for failed requests. Defaults to 5. files: Files resource for managing files @@ -50,6 +52,7 @@ class VLMRun: api_key: Optional[str] = None base_url: Optional[str] = None + agent_base_url: Optional[str] = None timeout: float = 120.0 max_retries: int = 5 @@ -79,6 +82,10 @@ def __post_init__(self): if self.base_url is None: self.base_url = os.getenv("VLMRUN_BASE_URL", DEFAULT_BASE_URL) + # Handle agent base URL (for OpenAI-compatible completions endpoint) + if self.agent_base_url is None: + self.agent_base_url = os.getenv("VLMRUN_AGENT_BASE_URL", self.base_url) + # Initialize requestor for API key validation requestor = APIRequestor( self, timeout=self.timeout, max_retries=self.max_retries diff --git a/vlmrun/types/abstract.py b/vlmrun/types/abstract.py index 701a6a8..4e99e3f 100644 --- a/vlmrun/types/abstract.py +++ b/vlmrun/types/abstract.py @@ -14,6 +14,7 @@ class VLMRunProtocol(Protocol): api_key: Optional[str] base_url: str + agent_base_url: str timeout: float files: Any datasets: Any