diff --git a/docs/clients.md b/docs/clients.md index 23f6f00..e240ac7 100644 --- a/docs/clients.md +++ b/docs/clients.md @@ -36,7 +36,9 @@ Client instances can be configured with a base URL that is used when constructin ```python >>> cli = httpx.Client(url="https://www.httpbin.org") >>> r = cli.get("/json") ->>> r.url +>>> r + +>>> r.request.url 'https://www.httpbin.org/json' ``` @@ -56,7 +58,7 @@ You can override this behavior by explicitly specifying the default headers... ```python >>> headers = {"User-Agent": "dev", "Accept-Encoding": "gzip"} >>> cli = httpx.Client(headers=headers) ->>> r = cli.get("") +>>> r = cli.get("https://www.example.com/") ``` ## Configuring the connection pool @@ -64,7 +66,12 @@ You can override this behavior by explicitly specifying the default headers... The connection pool used by the client can be configured in order to customise the SSL context, the maximum number of concurrent connections, or the network backend. ```python ->>> transport = httpx.ConnectionPool(ssl_context=httpx.SSLNoVerify()) +>>> # Setup an SSL context to allow connecting to improperly configured SSL. +>>> no_verify = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +>>> no_verify.check_hostname = False +>>> no_verify.verify_mode = ssl.CERT_NONE +>>> # Instantiate a client with our custom SSL context. +>>> transport = httpx.ConnectionPool(ssl_context=no_verify) >>> cli = httpx.Client(transport=transport) ``` @@ -101,7 +108,7 @@ class MockTransport(httpx.Transport): self._response = response @contextlib.contextmanager - def send(request): + def send(self, request): yield response def close(self): diff --git a/src/ahttpx/__init__.py b/src/ahttpx/__init__.py index 6f61e0a..5b09c20 100644 --- a/src/ahttpx/__init__.py +++ b/src/ahttpx/__init__.py @@ -1,20 +1,19 @@ from ._client import * # Client, open_client -from ._models import * # Content, File, Files, Form, Headers, JSON, MultiPart, Response, Request +from ._models import * # Content, File, Files, Form, Headers, JSON, MultiPart, Response, Request, Text from ._network import * # NetworkBackend, NetworkStream -from ._pool import * # Connection, HTTPTransport, Transport -from ._urls import * # InvalidURL, QueryParams, URL +from ._pool import * # Connection, ConnectionPool, Transport +from ._urls import * # QueryParams, URL __all__ = [ "Client", "Connection", + "ConnectionPool", "Content", "File", "Files", "Form", "Headers", - "HTTPTransport", - "InvalidURL", "JSON", "MultiPart", "NetworkBackend", @@ -22,7 +21,17 @@ "open_client", "Response", "Request", + "Text", "Transport", "QueryParams", "URL", ] + + +# Modules names are deliberately private here. +# We fix-up the public API space so that class `__repr__` properly reflects this... +# +# >>> httpx.Client +# +for attr in __all__: + setattr(locals()[attr], '__module__', 'httpx') diff --git a/src/ahttpx/_client.py b/src/ahttpx/_client.py index 777936d..8c2740a 100644 --- a/src/ahttpx/_client.py +++ b/src/ahttpx/_client.py @@ -4,26 +4,29 @@ from typing import AsyncIterable, AsyncIterator, Mapping from ._models import Content, Headers, Response, Request -from ._pool import ConnectionPool +from ._pool import ConnectionPool, Transport from ._urls import URL -__all__ = ["Client", "open_client"] +__all__ = ["Client", "Content", "open_client"] class Client: def __init__( self, url: URL | str | None = None, - headers: Headers | Mapping[str, str] | None = None + headers: Headers | Mapping[str, str] | None = None, + transport: Transport | None = None, ): if url is None: url = "" if headers is None: headers = {"User-Agent": "dev"} + if transport is None: + transport = ConnectionPool() self.url = URL(url) self.headers = Headers(headers) - self.transport = ConnectionPool() + self.transport = transport self.via = RedirectMiddleware(self.transport) def build_request( @@ -109,8 +112,8 @@ def __repr__(self): return f"" -class RedirectMiddleware: - def __init__(self, transport) -> None: +class RedirectMiddleware(Transport): + def __init__(self, transport: Transport) -> None: self._transport = transport def is_redirect(self, response: Response) -> bool: diff --git a/src/ahttpx/_models.py b/src/ahttpx/_models.py index db0f8d5..d8a9eda 100644 --- a/src/ahttpx/_models.py +++ b/src/ahttpx/_models.py @@ -28,6 +28,7 @@ "MultiPart", "Response", "Request", + "Text", ] # We're using the same set as stdlib `http.HTTPStatus` here... @@ -747,6 +748,16 @@ def encode(self) -> tuple[Headers, bytes | AsyncIterable[bytes]]: return (headers, content) +class Text(Content): + def __init__(self, text: str) -> None: + self._text = text + + def encode(self) -> tuple[Headers, bytes | AsyncIterable[bytes]]: + content = self._text.encode("utf-8") + headers = Headers({"Content-Type": "text/plain; charset='utf-8", "Content-Length": str(len(content))}) + return (headers, content) + + class MultiPart(Content): def __init__( self, diff --git a/src/ahttpx/_pool.py b/src/ahttpx/_pool.py index fdd73fe..58230ba 100644 --- a/src/ahttpx/_pool.py +++ b/src/ahttpx/_pool.py @@ -11,7 +11,16 @@ from ._network import Lock, NetworkBackend, Semaphore, NetworkStream -class ConnectionPool: +class Transport: + @contextlib.asynccontextmanager + async def send(self, request: Request) -> typing.AsyncIterator[Response]: + raise NotImplementedError() + + async def close(self): + pass + + +class ConnectionPool(Transport): def __init__(self): self._connections = [] self._ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) @@ -21,6 +30,9 @@ def __init__(self): # Public API... @contextlib.asynccontextmanager async def send(self, request: Request) -> typing.AsyncIterator[Response]: + if request.url.scheme not in ("http", "https"): + raise ValueError(f"Invalid URL {str(request.url)!r}. Scheme must be http or https.") + async with self._limit_concurrency: try: connection = await self._get_connection(request) @@ -91,7 +103,7 @@ def __exit__( self.close() -class Connection: +class Connection(Transport): def __init__(self, stream: "NetworkStream", origin: URL | str): self._stream = stream self._origin = URL(origin) diff --git a/src/httpx/__init__.py b/src/httpx/__init__.py index 6f61e0a..5b09c20 100644 --- a/src/httpx/__init__.py +++ b/src/httpx/__init__.py @@ -1,20 +1,19 @@ from ._client import * # Client, open_client -from ._models import * # Content, File, Files, Form, Headers, JSON, MultiPart, Response, Request +from ._models import * # Content, File, Files, Form, Headers, JSON, MultiPart, Response, Request, Text from ._network import * # NetworkBackend, NetworkStream -from ._pool import * # Connection, HTTPTransport, Transport -from ._urls import * # InvalidURL, QueryParams, URL +from ._pool import * # Connection, ConnectionPool, Transport +from ._urls import * # QueryParams, URL __all__ = [ "Client", "Connection", + "ConnectionPool", "Content", "File", "Files", "Form", "Headers", - "HTTPTransport", - "InvalidURL", "JSON", "MultiPart", "NetworkBackend", @@ -22,7 +21,17 @@ "open_client", "Response", "Request", + "Text", "Transport", "QueryParams", "URL", ] + + +# Modules names are deliberately private here. +# We fix-up the public API space so that class `__repr__` properly reflects this... +# +# >>> httpx.Client +# +for attr in __all__: + setattr(locals()[attr], '__module__', 'httpx') diff --git a/src/httpx/_client.py b/src/httpx/_client.py index 0e92f04..7e8d702 100644 --- a/src/httpx/_client.py +++ b/src/httpx/_client.py @@ -4,26 +4,29 @@ from typing import Iterable, Iterator, Mapping from ._models import Content, Headers, Response, Request -from ._pool import ConnectionPool +from ._pool import ConnectionPool, Transport from ._urls import URL -__all__ = ["Client", "open_client"] +__all__ = ["Client", "Content", "open_client"] class Client: def __init__( self, url: URL | str | None = None, - headers: Headers | Mapping[str, str] | None = None + headers: Headers | Mapping[str, str] | None = None, + transport: Transport | None = None, ): if url is None: url = "" if headers is None: headers = {"User-Agent": "dev"} + if transport is None: + transport = ConnectionPool() self.url = URL(url) self.headers = Headers(headers) - self.transport = ConnectionPool() + self.transport = transport self.via = RedirectMiddleware(self.transport) def build_request( diff --git a/src/httpx/_models.py b/src/httpx/_models.py index 5932006..ac80e72 100644 --- a/src/httpx/_models.py +++ b/src/httpx/_models.py @@ -28,6 +28,7 @@ "MultiPart", "Response", "Request", + "Text", ] # We're using the same set as stdlib `http.HTTPStatus` here... @@ -747,6 +748,16 @@ def encode(self) -> tuple[Headers, bytes | Iterable[bytes]]: return (headers, content) +class Text(Content): + def __init__(self, text: str) -> None: + self._text = text + + def encode(self) -> tuple[Headers, bytes | Iterable[bytes]]: + content = self._text.encode("utf-8") + headers = Headers({"Content-Type": "text/plain; charset='utf-8", "Content-Length": str(len(content))}) + return (headers, content) + + class MultiPart(Content): def __init__( self, diff --git a/src/httpx/_pool.py b/src/httpx/_pool.py index 9f9859b..bea3a1d 100644 --- a/src/httpx/_pool.py +++ b/src/httpx/_pool.py @@ -11,7 +11,16 @@ from ._network import Lock, NetworkBackend, Semaphore, NetworkStream -class ConnectionPool: +class Transport: + @contextlib.contextmanager + def send(self, request: Request) -> typing.Iterator[Response]: + raise NotImplementedError() + + def close(self): + pass + + +class ConnectionPool(Transport): def __init__(self): self._connections = [] self._ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) @@ -21,6 +30,9 @@ def __init__(self): # Public API... @contextlib.contextmanager def send(self, request: Request) -> typing.Iterator[Response]: + if request.url.scheme not in ("http", "https"): + raise ValueError(f"Invalid URL {str(request.url)!r}. Scheme must be http or https.") + with self._limit_concurrency: try: connection = self._get_connection(request) @@ -91,7 +103,7 @@ def __exit__( self.close() -class Connection: +class Connection(Transport): def __init__(self, stream: "NetworkStream", origin: URL | str): self._stream = stream self._origin = URL(origin) diff --git a/tests/test_00_quickstart.py b/tests/test_00_quickstart.py index 4d53680..4249924 100644 --- a/tests/test_00_quickstart.py +++ b/tests/test_00_quickstart.py @@ -3,7 +3,7 @@ def test_cli(): cli = httpx.Client() - assert repr(cli) == "" + assert repr(cli) == "" # def test_post(httpbin): diff --git a/tests/test_01_clients.py b/tests/test_01_clients.py index dd71cbc..f4f7311 100644 --- a/tests/test_01_clients.py +++ b/tests/test_01_clients.py @@ -10,7 +10,7 @@ def cli(): def test_client(): client = httpx.Client() - assert repr(client) == "" + assert repr(client) == "" def test_get(httpbin, cli):