Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,017 changes: 1,008 additions & 9 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ requests = "^2.0.0"
tokenizers = ">=0.15,<1"
types-requests = "^2.0.0"
typing_extensions = ">= 4.0.0"
aiohttp = {version = ">=3.0", optional = true}
httpx_aiohttp = {version = ">=0.1.8", optional = true}

[tool.poetry.extras]
aiohttp = ["aiohttp", "httpx_aiohttp"]

[tool.poetry.group.dev.dependencies]
mypy = "==1.13.0"
Expand Down
5 changes: 5 additions & 0 deletions src/cohere/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@
)
from .bedrock_client import BedrockClient, BedrockClientV2
from .client import AsyncClient, Client
from ._default_clients import DefaultAioHttpClient, DefaultAsyncHttpxClient
from .client_v2 import AsyncClientV2, ClientV2
from .datasets import DatasetsCreateResponse, DatasetsGetResponse, DatasetsGetUsageResponse, DatasetsListResponse
from .embed_jobs import CreateEmbedJobRequestTruncate
Expand Down Expand Up @@ -440,6 +441,8 @@
"CreateEmbedJobResponse": ".types",
"Dataset": ".types",
"DatasetPart": ".types",
"DefaultAioHttpClient": "._default_clients",
"DefaultAsyncHttpxClient": "._default_clients",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New entries break alphabetical ordering in lookup dict

Low Severity

"DefaultAioHttpClient" and "DefaultAsyncHttpxClient" are inserted between "DatasetPart" and "DatasetType" in the _dynamic_imports dict, splitting the "Dataset*" group and breaking the alphabetical sort convention maintained throughout the file. Alphabetically, "Default" sorts after "Debug" (since 'b' < 'f'), so these entries belong after "DebugV2ChatStreamResponse". The same misordering occurs in __all__, where the entries are placed before the "Debug*" entries instead of after them.

Additional Locations (1)
Fix in Cursor Fix in Web

"DatasetType": ".types",
"DatasetValidationStatus": ".types",
"DatasetsCreateResponse": ".datasets",
Expand Down Expand Up @@ -779,6 +782,8 @@ def __dir__():
"DatasetsGetResponse",
"DatasetsGetUsageResponse",
"DatasetsListResponse",
"DefaultAioHttpClient",
"DefaultAsyncHttpxClient",
"DebugStreamedChatResponse",
"DebugV2ChatStreamResponse",
"DeleteConnectorResponse",
Expand Down
31 changes: 31 additions & 0 deletions src/cohere/_default_clients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import typing

import httpx

COHERE_DEFAULT_TIMEOUT = 300

try:
import httpx_aiohttp
except ImportError:

class DefaultAioHttpClient(httpx.AsyncClient): # type: ignore
def __init__(self, **kwargs: typing.Any) -> None:
raise RuntimeError(
"To use the aiohttp client, install the aiohttp extra: "
"pip install cohere[aiohttp]"
)

else:

class DefaultAioHttpClient(httpx_aiohttp.HttpxAiohttpClient): # type: ignore
def __init__(self, **kwargs: typing.Any) -> None:
kwargs.setdefault("timeout", COHERE_DEFAULT_TIMEOUT)
kwargs.setdefault("follow_redirects", True)
super().__init__(**kwargs)


class DefaultAsyncHttpxClient(httpx.AsyncClient):
def __init__(self, **kwargs: typing.Any) -> None:
kwargs.setdefault("timeout", COHERE_DEFAULT_TIMEOUT)
kwargs.setdefault("follow_redirects", True)
super().__init__(**kwargs)
22 changes: 19 additions & 3 deletions src/cohere/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1565,6 +1565,24 @@ def finetuning(self):
return self._finetuning


def _make_default_async_client(
timeout: float,
follow_redirects: typing.Optional[bool],
) -> httpx.AsyncClient:
try:
import httpx_aiohttp
except ImportError:
pass
else:
if follow_redirects is not None:
return httpx_aiohttp.HttpxAiohttpClient(timeout=timeout, follow_redirects=follow_redirects)
return httpx_aiohttp.HttpxAiohttpClient(timeout=timeout)

if follow_redirects is not None:
return httpx.AsyncClient(timeout=timeout, follow_redirects=follow_redirects)
return httpx.AsyncClient(timeout=timeout)


class AsyncBaseCohere:
"""
Use this class to access the different functions within the SDK. You can instantiate any number of clients with different configuration that will propagate to these functions.
Expand Down Expand Up @@ -1631,9 +1649,7 @@ def __init__(
headers=headers,
httpx_client=httpx_client
if httpx_client is not None
else httpx.AsyncClient(timeout=_defaulted_timeout, follow_redirects=follow_redirects)
if follow_redirects is not None
else httpx.AsyncClient(timeout=_defaulted_timeout),
else _make_default_async_client(timeout=_defaulted_timeout, follow_redirects=follow_redirects),
timeout=_defaulted_timeout,
)
self._raw_client = AsyncRawBaseCohere(client_wrapper=self._client_wrapper)
Expand Down
119 changes: 119 additions & 0 deletions tests/test_aiohttp_autodetect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import sys
import typing
import unittest
from unittest import mock

import httpx


class TestMakeDefaultAsyncClient(unittest.TestCase):
"""Tests for _make_default_async_client in base_client.py."""

def test_without_httpx_aiohttp_returns_httpx_async_client(self) -> None:
"""When httpx_aiohttp is not installed, returns plain httpx.AsyncClient."""
with mock.patch.dict(sys.modules, {"httpx_aiohttp": None}):
# Re-import to pick up the mocked module state
from cohere.base_client import _make_default_async_client

client = _make_default_async_client(timeout=300, follow_redirects=True)
self.assertIsInstance(client, httpx.AsyncClient)
self.assertEqual(client.timeout.read, 300)
self.assertTrue(client.follow_redirects)

def test_without_httpx_aiohttp_follow_redirects_none(self) -> None:
"""When follow_redirects is None, omits it from httpx.AsyncClient."""
with mock.patch.dict(sys.modules, {"httpx_aiohttp": None}):
from cohere.base_client import _make_default_async_client

client = _make_default_async_client(timeout=300, follow_redirects=None)
self.assertIsInstance(client, httpx.AsyncClient)
# httpx default is False when not specified
self.assertFalse(client.follow_redirects)

def test_with_httpx_aiohttp_returns_aiohttp_client(self) -> None:
"""When httpx_aiohttp is installed, returns HttpxAiohttpClient."""
try:
import httpx_aiohttp
except ImportError:
self.skipTest("httpx_aiohttp not installed")

from cohere.base_client import _make_default_async_client

client = _make_default_async_client(timeout=300, follow_redirects=True)
self.assertIsInstance(client, httpx_aiohttp.HttpxAiohttpClient)
self.assertEqual(client.timeout.read, 300)
self.assertTrue(client.follow_redirects)

def test_with_httpx_aiohttp_follow_redirects_none(self) -> None:
"""When httpx_aiohttp is installed and follow_redirects is None, omits it."""
try:
import httpx_aiohttp
except ImportError:
self.skipTest("httpx_aiohttp not installed")

from cohere.base_client import _make_default_async_client

client = _make_default_async_client(timeout=300, follow_redirects=None)
self.assertIsInstance(client, httpx_aiohttp.HttpxAiohttpClient)
# httpx default is False when not specified
self.assertFalse(client.follow_redirects)

def test_explicit_httpx_client_bypasses_autodetect(self) -> None:
"""When user passes httpx_client explicitly, auto-detect is not used."""
explicit_client = httpx.AsyncClient(timeout=60)
# Simulate what AsyncBaseCohere.__init__ does:
# httpx_client if httpx_client is not None else _make_default_async_client(...)
result = explicit_client if explicit_client is not None else None
self.assertIs(result, explicit_client)
self.assertEqual(result.timeout.read, 60)


class TestDefaultClients(unittest.TestCase):
"""Tests for convenience classes in _default_clients.py."""

def test_default_async_httpx_client_defaults(self) -> None:
"""DefaultAsyncHttpxClient applies SDK defaults."""
from cohere._default_clients import COHERE_DEFAULT_TIMEOUT, DefaultAsyncHttpxClient

client = DefaultAsyncHttpxClient()
self.assertIsInstance(client, httpx.AsyncClient)
self.assertEqual(client.timeout.read, COHERE_DEFAULT_TIMEOUT)
self.assertTrue(client.follow_redirects)

def test_default_async_httpx_client_overrides(self) -> None:
"""DefaultAsyncHttpxClient allows overriding defaults."""
from cohere._default_clients import DefaultAsyncHttpxClient

client = DefaultAsyncHttpxClient(timeout=60, follow_redirects=False)
self.assertEqual(client.timeout.read, 60)
self.assertFalse(client.follow_redirects)

def test_default_aiohttp_client_without_package(self) -> None:
"""DefaultAioHttpClient raises RuntimeError when httpx_aiohttp not installed."""
with mock.patch.dict(sys.modules, {"httpx_aiohttp": None}):
# Need to reload the module to pick up the mock
import importlib
import cohere._default_clients

importlib.reload(cohere._default_clients)

with self.assertRaises(RuntimeError) as ctx:
cohere._default_clients.DefaultAioHttpClient()
self.assertIn("pip install cohere[aiohttp]", str(ctx.exception))

# Reload again to restore original state
importlib.reload(cohere._default_clients)

def test_default_aiohttp_client_with_package(self) -> None:
"""DefaultAioHttpClient works when httpx_aiohttp is installed."""
try:
import httpx_aiohttp
except ImportError:
self.skipTest("httpx_aiohttp not installed")

from cohere._default_clients import COHERE_DEFAULT_TIMEOUT, DefaultAioHttpClient

client = DefaultAioHttpClient()
self.assertIsInstance(client, httpx_aiohttp.HttpxAiohttpClient)
self.assertEqual(client.timeout.read, COHERE_DEFAULT_TIMEOUT)
self.assertTrue(client.follow_redirects)
Loading