diff --git a/src/viam/app/viam_client.py b/src/viam/app/viam_client.py index 2071dd2d5..ae1f4f5c8 100644 --- a/src/viam/app/viam_client.py +++ b/src/viam/app/viam_client.py @@ -88,7 +88,7 @@ async def create_from_dial_options(cls, dial_options: DialOptions, app_url: Opti self._dial_options = dial_options if app_url is None: app_url = "app.viam.com" - self._channel = await _dial_app(app_url) + self._channel = await _dial_app(app_url, dial_options) access_token = await _get_access_token(self._channel, dial_options.auth_entity, dial_options) self._metadata = {"authorization": f"Bearer {access_token}"} return self diff --git a/src/viam/rpc/dial.py b/src/viam/rpc/dial.py index 80016f617..78382fe70 100644 --- a/src/viam/rpc/dial.py +++ b/src/viam/rpc/dial.py @@ -416,5 +416,5 @@ async def dial_direct(address: str, options: Optional[DialOptions] = None) -> Ch return await _dial_direct(address, options) -async def _dial_app(app_url: str) -> Channel: - return await _dial_direct(app_url) +async def _dial_app(app_url: str, options: Optional[DialOptions] = None) -> Channel: + return await _dial_direct(app_url, options) diff --git a/tests/test_viam_client.py b/tests/test_viam_client.py index 71af619c4..9f0ab65a8 100644 --- a/tests/test_viam_client.py +++ b/tests/test_viam_client.py @@ -44,6 +44,34 @@ async def test_sets_fields(self): assert DIAL_OPTIONS.auth_entity is not None assert client._metadata == {"authorization": f"Bearer {ACCESS_TOKEN}"} + async def test_passes_dial_options_to_dial_app(self): + async with ChannelFor([]) as channel: + with patch("viam.app.viam_client._dial_app") as patched_dial: + patched_dial.return_value = channel + with patch("viam.app.viam_client._get_access_token") as patched_auth: + patched_auth.return_value = "MY_ACCESS_TOKEN" + + creds = Credentials("api-key", "SOME_API_KEY") + dial_options = DialOptions(credentials=creds, auth_entity=str(uuid4()), insecure=True) + + await ViamClient.create_from_dial_options(dial_options) + + patched_dial.assert_called_once_with("app.viam.com", dial_options) + + async def test_passes_dial_options_with_custom_url(self): + async with ChannelFor([]) as channel: + with patch("viam.app.viam_client._dial_app") as patched_dial: + patched_dial.return_value = channel + with patch("viam.app.viam_client._get_access_token") as patched_auth: + patched_auth.return_value = "MY_ACCESS_TOKEN" + + creds = Credentials("api-key", "SOME_API_KEY") + dial_options = DialOptions(credentials=creds, auth_entity=str(uuid4()), insecure=True) + + await ViamClient.create_from_dial_options(dial_options, app_url="localhost:8080") + + patched_dial.assert_called_once_with("localhost:8080", dial_options) + async def test_clients(self): async with ChannelFor([]) as channel: with patch("viam.app.viam_client._dial_app") as patched_dial: