-
Notifications
You must be signed in to change notification settings - Fork 84
feat: auto-detect aiohttp transport for async client #739
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
fern-support
wants to merge
2
commits into
main
Choose a base branch
from
fer-8644-aiohttp-autodetect
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| import typing | ||
|
|
||
| import httpx | ||
|
|
||
| COHERE_DEFAULT_TIMEOUT = 300 | ||
|
|
||
| try: | ||
| import httpx_aiohttp | ||
| except ImportError: | ||
|
|
||
| class DefaultAioHttpClient(httpx.AsyncClient): # type: ignore | ||
| def __init__(self, **kwargs: typing.Any) -> None: | ||
| raise RuntimeError( | ||
| "To use the aiohttp client, install the aiohttp extra: " | ||
| "pip install cohere[aiohttp]" | ||
| ) | ||
|
|
||
| else: | ||
|
|
||
| class DefaultAioHttpClient(httpx_aiohttp.HttpxAiohttpClient): # type: ignore | ||
| def __init__(self, **kwargs: typing.Any) -> None: | ||
| kwargs.setdefault("timeout", COHERE_DEFAULT_TIMEOUT) | ||
| kwargs.setdefault("follow_redirects", True) | ||
| super().__init__(**kwargs) | ||
|
|
||
|
|
||
| class DefaultAsyncHttpxClient(httpx.AsyncClient): | ||
| def __init__(self, **kwargs: typing.Any) -> None: | ||
| kwargs.setdefault("timeout", COHERE_DEFAULT_TIMEOUT) | ||
| kwargs.setdefault("follow_redirects", True) | ||
| super().__init__(**kwargs) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| import sys | ||
| import typing | ||
| import unittest | ||
| from unittest import mock | ||
|
|
||
| import httpx | ||
|
|
||
|
|
||
| class TestMakeDefaultAsyncClient(unittest.TestCase): | ||
| """Tests for _make_default_async_client in base_client.py.""" | ||
|
|
||
| def test_without_httpx_aiohttp_returns_httpx_async_client(self) -> None: | ||
| """When httpx_aiohttp is not installed, returns plain httpx.AsyncClient.""" | ||
| with mock.patch.dict(sys.modules, {"httpx_aiohttp": None}): | ||
| # Re-import to pick up the mocked module state | ||
| from cohere.base_client import _make_default_async_client | ||
|
|
||
| client = _make_default_async_client(timeout=300, follow_redirects=True) | ||
| self.assertIsInstance(client, httpx.AsyncClient) | ||
| self.assertEqual(client.timeout.read, 300) | ||
| self.assertTrue(client.follow_redirects) | ||
|
|
||
| def test_without_httpx_aiohttp_follow_redirects_none(self) -> None: | ||
| """When follow_redirects is None, omits it from httpx.AsyncClient.""" | ||
| with mock.patch.dict(sys.modules, {"httpx_aiohttp": None}): | ||
| from cohere.base_client import _make_default_async_client | ||
|
|
||
| client = _make_default_async_client(timeout=300, follow_redirects=None) | ||
| self.assertIsInstance(client, httpx.AsyncClient) | ||
| # httpx default is False when not specified | ||
| self.assertFalse(client.follow_redirects) | ||
|
|
||
| def test_with_httpx_aiohttp_returns_aiohttp_client(self) -> None: | ||
| """When httpx_aiohttp is installed, returns HttpxAiohttpClient.""" | ||
| try: | ||
| import httpx_aiohttp | ||
| except ImportError: | ||
| self.skipTest("httpx_aiohttp not installed") | ||
|
|
||
| from cohere.base_client import _make_default_async_client | ||
|
|
||
| client = _make_default_async_client(timeout=300, follow_redirects=True) | ||
| self.assertIsInstance(client, httpx_aiohttp.HttpxAiohttpClient) | ||
| self.assertEqual(client.timeout.read, 300) | ||
| self.assertTrue(client.follow_redirects) | ||
|
|
||
| def test_with_httpx_aiohttp_follow_redirects_none(self) -> None: | ||
| """When httpx_aiohttp is installed and follow_redirects is None, omits it.""" | ||
| try: | ||
| import httpx_aiohttp | ||
| except ImportError: | ||
| self.skipTest("httpx_aiohttp not installed") | ||
|
|
||
| from cohere.base_client import _make_default_async_client | ||
|
|
||
| client = _make_default_async_client(timeout=300, follow_redirects=None) | ||
| self.assertIsInstance(client, httpx_aiohttp.HttpxAiohttpClient) | ||
| # httpx default is False when not specified | ||
| self.assertFalse(client.follow_redirects) | ||
|
|
||
| def test_explicit_httpx_client_bypasses_autodetect(self) -> None: | ||
| """When user passes httpx_client explicitly, auto-detect is not used.""" | ||
| explicit_client = httpx.AsyncClient(timeout=60) | ||
| # Simulate what AsyncBaseCohere.__init__ does: | ||
| # httpx_client if httpx_client is not None else _make_default_async_client(...) | ||
| result = explicit_client if explicit_client is not None else None | ||
| self.assertIs(result, explicit_client) | ||
| self.assertEqual(result.timeout.read, 60) | ||
|
|
||
|
|
||
| class TestDefaultClients(unittest.TestCase): | ||
| """Tests for convenience classes in _default_clients.py.""" | ||
|
|
||
| def test_default_async_httpx_client_defaults(self) -> None: | ||
| """DefaultAsyncHttpxClient applies SDK defaults.""" | ||
| from cohere._default_clients import COHERE_DEFAULT_TIMEOUT, DefaultAsyncHttpxClient | ||
|
|
||
| client = DefaultAsyncHttpxClient() | ||
| self.assertIsInstance(client, httpx.AsyncClient) | ||
| self.assertEqual(client.timeout.read, COHERE_DEFAULT_TIMEOUT) | ||
| self.assertTrue(client.follow_redirects) | ||
|
|
||
| def test_default_async_httpx_client_overrides(self) -> None: | ||
| """DefaultAsyncHttpxClient allows overriding defaults.""" | ||
| from cohere._default_clients import DefaultAsyncHttpxClient | ||
|
|
||
| client = DefaultAsyncHttpxClient(timeout=60, follow_redirects=False) | ||
| self.assertEqual(client.timeout.read, 60) | ||
| self.assertFalse(client.follow_redirects) | ||
|
|
||
| def test_default_aiohttp_client_without_package(self) -> None: | ||
| """DefaultAioHttpClient raises RuntimeError when httpx_aiohttp not installed.""" | ||
| with mock.patch.dict(sys.modules, {"httpx_aiohttp": None}): | ||
| # Need to reload the module to pick up the mock | ||
| import importlib | ||
| import cohere._default_clients | ||
|
|
||
| importlib.reload(cohere._default_clients) | ||
|
|
||
| with self.assertRaises(RuntimeError) as ctx: | ||
| cohere._default_clients.DefaultAioHttpClient() | ||
| self.assertIn("pip install cohere[aiohttp]", str(ctx.exception)) | ||
|
|
||
| # Reload again to restore original state | ||
| importlib.reload(cohere._default_clients) | ||
|
|
||
| def test_default_aiohttp_client_with_package(self) -> None: | ||
| """DefaultAioHttpClient works when httpx_aiohttp is installed.""" | ||
| try: | ||
| import httpx_aiohttp | ||
| except ImportError: | ||
| self.skipTest("httpx_aiohttp not installed") | ||
|
|
||
| from cohere._default_clients import COHERE_DEFAULT_TIMEOUT, DefaultAioHttpClient | ||
|
|
||
| client = DefaultAioHttpClient() | ||
| self.assertIsInstance(client, httpx_aiohttp.HttpxAiohttpClient) | ||
| self.assertEqual(client.timeout.read, COHERE_DEFAULT_TIMEOUT) | ||
| self.assertTrue(client.follow_redirects) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New entries break alphabetical ordering in lookup dict
Low Severity
"DefaultAioHttpClient"and"DefaultAsyncHttpxClient"are inserted between"DatasetPart"and"DatasetType"in the_dynamic_importsdict, splitting the"Dataset*"group and breaking the alphabetical sort convention maintained throughout the file. Alphabetically,"Default"sorts after"Debug"(since 'b' < 'f'), so these entries belong after"DebugV2ChatStreamResponse". The same misordering occurs in__all__, where the entries are placed before the"Debug*"entries instead of after them.Additional Locations (1)
src/cohere/__init__.py#L784-L786