Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions httpcore/_async/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
uds: str | None = None,
network_backend: AsyncNetworkBackend | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
on_capacity_update: typing.Callable[[], typing.Any] | None = None,
) -> None:
self._origin = origin
self._ssl_context = ssl_context
Expand All @@ -65,6 +66,7 @@ def __init__(
self._connect_failed: bool = False
self._request_lock = AsyncLock()
self._socket_options = socket_options
self._on_capacity_update = on_capacity_update

async def handle_async_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
Expand All @@ -89,6 +91,7 @@ async def handle_async_request(self, request: Request) -> Response:
origin=self._origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
on_capacity_update=self._on_capacity_update,
)
else:
self._connection = AsyncHTTP11Connection(
Expand Down Expand Up @@ -184,6 +187,11 @@ def is_available(self) -> bool:
)
return self._connection.is_available()

def max_concurrent_requests(self) -> int:
if self._connection is None:
return 1
return self._connection.max_concurrent_requests()

def has_expired(self) -> bool:
if self._connection is None:
return self._connect_failed
Expand Down
80 changes: 61 additions & 19 deletions httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
http1=self._http1,
http2=self._http2,
network_backend=self._network_backend,
on_capacity_update=self._connection_capacity_updated,
)
elif origin.scheme == b"http":
from .http_proxy import AsyncForwardHTTPConnection
Expand All @@ -150,6 +151,7 @@ def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
remote_origin=origin,
keepalive_expiry=self._keepalive_expiry,
network_backend=self._network_backend,
on_capacity_update=self._connection_capacity_updated,
)
from .http_proxy import AsyncTunnelHTTPConnection

Expand All @@ -163,6 +165,7 @@ def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
http1=self._http1,
http2=self._http2,
network_backend=self._network_backend,
on_capacity_update=self._connection_capacity_updated,
)

return AsyncHTTPConnection(
Expand All @@ -176,6 +179,7 @@ def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
uds=self._uds,
network_backend=self._network_backend,
socket_options=self._socket_options,
on_capacity_update=self._connection_capacity_updated,
)

@property
Expand Down Expand Up @@ -289,55 +293,93 @@ def _assign_requests_to_connections(self) -> list[AsyncConnectionInterface]:
# log: "closing expired connection"
self._connections.remove(connection)
closing_connections.append(connection)
elif (
connection.is_idle()
and sum(connection.is_idle() for connection in self._connections)
> self._max_keepalive_connections
):

idle_connection_count = sum(
connection.is_idle() for connection in self._connections
)
for connection in list(self._connections):
if idle_connection_count <= self._max_keepalive_connections:
break
if connection.is_idle():
# log: "closing idle connection"
self._connections.remove(connection)
closing_connections.append(connection)
idle_connection_count -= 1

# Assign queued requests to connections.
queued_requests = [request for request in self._requests if request.is_queued()]
connection_request_count = dict.fromkeys(self._connections, 0)
for request in self._requests:
request_connection = request.connection
if request_connection in connection_request_count:
connection_request_count[request_connection] += 1

for pool_request in queued_requests:
origin = pool_request.request.url.origin
available_connections = [
connection
for connection in self._connections
if connection.can_handle_request(origin) and connection.is_available()
]
idle_connections = [
connection for connection in self._connections if connection.is_idle()
]
available_connection = next(
(
connection
for connection in self._connections
if (
connection.can_handle_request(origin)
and connection.is_available()
and connection_request_count[connection]
< self._max_concurrent_requests(connection)
)
),
None,
)

# There are three cases for how we may be able to handle the request:
#
# 1. There is an existing connection that can handle the request.
# 2. We can create a new connection to handle the request.
# 3. We can close an idle connection and then create a new connection
# to handle the request.
if available_connections:
if available_connection is not None:
# log: "reusing existing connection"
connection = available_connections[0]
connection = available_connection
pool_request.assign_to_connection(connection)
connection_request_count[connection] += 1
elif len(self._connections) < self._max_connections:
# log: "creating new connection"
connection = self.create_connection(origin)
self._connections.append(connection)
pool_request.assign_to_connection(connection)
elif idle_connections:
connection_request_count[connection] = 1
else:
idle_connection = next(
(
connection
for connection in self._connections
if connection.is_idle()
),
None,
)
if idle_connection is None:
continue
# log: "closing idle connection"
connection = idle_connections[0]
self._connections.remove(connection)
closing_connections.append(connection)
self._connections.remove(idle_connection)
closing_connections.append(idle_connection)
# log: "creating new connection"
connection = self.create_connection(origin)
self._connections.append(connection)
pool_request.assign_to_connection(connection)
connection_request_count[connection] = 1

return closing_connections

def _max_concurrent_requests(self, connection: AsyncConnectionInterface) -> int:
try:
return int(connection.max_concurrent_requests())
except AttributeError: # pragma: nocover
return 1

async def _connection_capacity_updated(self) -> None:
with self._optional_thread_lock:
closing = self._assign_requests_to_connections()
await self._close_connections(closing)

async def _close_connections(self, closing: list[AsyncConnectionInterface]) -> None:
# Close connections which have been removed from the pool.
with AsyncShieldCancellation():
Expand Down
3 changes: 3 additions & 0 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ def is_available(self) -> bool:
# acquired from the connection pool for any other request.
return self._state == HTTPConnectionState.IDLE

def max_concurrent_requests(self) -> int:
return 1

def has_expired(self) -> bool:
now = time.monotonic()
keepalive_expired = self._expire_at is not None and now > self._expire_at
Expand Down
63 changes: 52 additions & 11 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ def __init__(
origin: Origin,
stream: AsyncNetworkStream,
keepalive_expiry: float | None = None,
on_capacity_update: typing.Callable[[], typing.Any] | None = None,
):
self._origin = origin
self._network_stream = stream
self._keepalive_expiry: float | None = keepalive_expiry
self._on_capacity_update = on_capacity_update
self._h2_state = h2.connection.H2Connection(config=self.CONFIG)
self._state = HTTPConnectionState.IDLE
self._expire_at: float | None = None
Expand All @@ -74,6 +76,7 @@ def __init__(
| h2.events.StreamReset,
],
] = {}
self._closed_streams: set[int] = set()

# Connection terminated events are stored as state since
# we need to handle them for all streams.
Expand All @@ -95,6 +98,8 @@ async def handle_async_request(self, request: Request) -> Response:

async with self._state_lock:
if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE):
previous_state = self._state
previous_expire_at = self._expire_at
self._request_count += 1
self._expire_at = None
self._state = HTTPConnectionState.ACTIVE
Expand Down Expand Up @@ -128,14 +133,21 @@ async def handle_async_request(self, request: Request) -> Response:
for _ in range(local_settings_max_streams - self._max_streams):
await self._max_streams_semaphore.acquire()

await self._max_streams_semaphore.acquire()
if not self._max_streams_semaphore.acquire_nowait():
async with self._state_lock:
self._request_count -= 1
if not self._events: # pragma: nocover
self._state = previous_state
self._expire_at = previous_expire_at
raise ConnectionNotAvailable()

try:
stream_id = self._h2_state.get_next_available_stream_id()
self._events[stream_id] = []
except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover
self._used_all_stream_ids = True
self._request_count -= 1
await self._max_streams_semaphore.release()
raise ConnectionNotAvailable()

try:
Expand Down Expand Up @@ -380,6 +392,10 @@ async def _receive_events(
),
):
if event.stream_id in self._events:
if isinstance(
event, (h2.events.StreamEnded, h2.events.StreamReset)
):
self._closed_streams.add(event.stream_id)
self._events[event.stream_id].append(event)

elif isinstance(event, h2.events.ConnectionTerminated):
Expand All @@ -399,27 +415,48 @@ async def _receive_remote_settings_change(
self._h2_state.local_settings.max_concurrent_streams,
)
if new_max_streams and new_max_streams != self._max_streams:
while new_max_streams > self._max_streams:
active_stream_count = len(self._events)
old_available_streams = max(0, self._max_streams - active_stream_count)
new_available_streams = max(0, new_max_streams - active_stream_count)
self._max_streams = new_max_streams
while new_available_streams > old_available_streams:
await self._max_streams_semaphore.release()
self._max_streams += 1
while new_max_streams < self._max_streams:
old_available_streams += 1
while new_available_streams < old_available_streams:
await self._max_streams_semaphore.acquire()
self._max_streams -= 1
old_available_streams -= 1
if self._on_capacity_update is not None:
await self._on_capacity_update()

async def _response_closed(self, stream_id: int) -> None:
await self._max_streams_semaphore.release()
stream_was_reset = stream_id not in self._closed_streams
if stream_was_reset:
# Keep h2's stream state aligned without blocking close/cancel on I/O.
# Any pending RST_STREAM data will be flushed by the next write.
try:
self._h2_state.reset_stream(stream_id)
except (
h2.exceptions.NoSuchStreamError,
h2.exceptions.StreamClosedError,
h2.exceptions.ProtocolError,
):
pass
if len(self._events) <= self._max_streams:
await self._max_streams_semaphore.release()
self._closed_streams.discard(stream_id)
del self._events[stream_id]
async with self._state_lock:
if self._connection_terminated and not self._events:
await self.aclose()

elif self._state == HTTPConnectionState.ACTIVE and not self._events:
self._state = HTTPConnectionState.IDLE
if self._keepalive_expiry is not None:
now = time.monotonic()
self._expire_at = now + self._keepalive_expiry
if self._used_all_stream_ids: # pragma: nocover
if stream_was_reset or self._used_all_stream_ids:
await self.aclose()
else:
self._state = HTTPConnectionState.IDLE
if self._keepalive_expiry is not None:
now = time.monotonic()
self._expire_at = now + self._keepalive_expiry

async def aclose(self) -> None:
# Note that this method unilaterally closes the connection, and does
Expand Down Expand Up @@ -513,12 +550,16 @@ def is_available(self) -> bool:
self._state != HTTPConnectionState.CLOSED
and not self._connection_error
and not self._used_all_stream_ids
and len(self._events) < self.max_concurrent_requests()
and not (
self._h2_state.state_machine.state
== h2.connection.ConnectionState.CLOSED
)
)

def max_concurrent_requests(self) -> int:
return self._max_streams if self._sent_connection_init else 1

def has_expired(self) -> bool:
now = time.monotonic()
return self._expire_at is not None and now > self._expire_at
Expand Down
12 changes: 12 additions & 0 deletions httpcore/_async/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,15 @@ def __init__(
network_backend: AsyncNetworkBackend | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
proxy_ssl_context: ssl.SSLContext | None = None,
on_capacity_update: typing.Callable[[], typing.Any] | None = None,
) -> None:
self._connection = AsyncHTTPConnection(
origin=proxy_origin,
keepalive_expiry=keepalive_expiry,
network_backend=network_backend,
socket_options=socket_options,
ssl_context=proxy_ssl_context,
on_capacity_update=on_capacity_update,
)
self._proxy_origin = proxy_origin
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
Expand Down Expand Up @@ -217,6 +219,9 @@ def info(self) -> str:
def is_available(self) -> bool:
return self._connection.is_available()

def max_concurrent_requests(self) -> int: # pragma: nocover
return self._connection.max_concurrent_requests()

def has_expired(self) -> bool:
return self._connection.has_expired()

Expand All @@ -243,13 +248,15 @@ def __init__(
http2: bool = False,
network_backend: AsyncNetworkBackend | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
on_capacity_update: typing.Callable[[], typing.Any] | None = None,
) -> None:
self._connection: AsyncConnectionInterface = AsyncHTTPConnection(
origin=proxy_origin,
keepalive_expiry=keepalive_expiry,
network_backend=network_backend,
socket_options=socket_options,
ssl_context=proxy_ssl_context,
on_capacity_update=on_capacity_update,
)
self._proxy_origin = proxy_origin
self._remote_origin = remote_origin
Expand All @@ -259,6 +266,7 @@ def __init__(
self._keepalive_expiry = keepalive_expiry
self._http1 = http1
self._http2 = http2
self._on_capacity_update = on_capacity_update
self._connect_lock = AsyncLock()
self._connected = False

Expand Down Expand Up @@ -331,6 +339,7 @@ async def handle_async_request(self, request: Request) -> Response:
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
on_capacity_update=self._on_capacity_update,
)
else:
self._connection = AsyncHTTP11Connection(
Expand All @@ -354,6 +363,9 @@ def info(self) -> str:
def is_available(self) -> bool:
return self._connection.is_available()

def max_concurrent_requests(self) -> int: # pragma: nocover
return self._connection.max_concurrent_requests()

def has_expired(self) -> bool:
return self._connection.has_expired()

Expand Down
Loading
Loading