From 485365547a2ed5d2b4fc66813bda9340026dfb2e Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Fri, 20 Mar 2026 21:11:34 -0600 Subject: [PATCH] Add support for connector factory in connect args This commit adds support for providing a connector factory, a callable that provides full control over how the network connection to the PostgreSQL server is established. When specified, this factory replaces the default connection logic. The callable receives the following arguments: - *proto_factory* - a callable that returns the asyncpg protocol instance - *host* - the target hostname (positional) - *port* - the target port (positional) - *loop* - the event loop (keyword argument) - *ssl* - the SSL context, or ``None`` (keyword argument) The callable must return an awaitable that resolves to a `(transport, protocol)` tuple, compatible with `asyncio.loop.create_connection`. This is useful for scenarios such as connecting through a proxy, establishing an SSH tunnel, or performing custom socket setup before the PostgreSQL protocol begins. Usage looks like this: ``` async def my_connector(proto_factory, host, port, *, loop, ssl): tunnel_sock = await open_ssh_tunnel(host, port) return await loop.create_connection( proto_factory, sock=tunnel_sock, ssl=ssl) conn = await asyncpg.connect( connector_factory=my_connector, host='db.example.com', user='postgres', ) ``` Fixes #1054. --- asyncpg/connect_utils.py | 20 +++++++++++-- asyncpg/connection.py | 40 ++++++++++++++++++++++++- tests/test_connect.py | 63 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+), 4 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 07c4fdde..a5d8c1c2 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -66,6 +66,7 @@ class SSLNegotiation(compat.StrEnum): 'target_session_attrs', 'krbsrvname', 'gsslib', + 'connector_factory', ]) @@ -854,7 +855,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, sslmode=sslmode, ssl_negotiation=sslneg, server_settings=server_settings, target_session_attrs=target_session_attrs, - krbsrvname=krbsrvname, gsslib=gsslib) + krbsrvname=krbsrvname, gsslib=gsslib, connector_factory=None) return addrs, params @@ -866,7 +867,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, max_cacheable_statement_size, ssl, direct_tls, server_settings, target_session_attrs, krbsrvname, gsslib, - service, servicefile): + service, servicefile, connector_factory=None): local_vars = locals() for var_name in {'max_cacheable_statement_size', 'max_cached_statement_lifetime', @@ -899,6 +900,15 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, krbsrvname=krbsrvname, gsslib=gsslib, service=service, servicefile=servicefile) + if connector_factory is not None: + if not callable(connector_factory): + raise TypeError( + "connector_factory is expected to be a callable, got {!r}".format( + type(connector_factory) + ) + ) + params = params._replace(connector_factory=connector_factory) + config = _ClientConfiguration( command_timeout=command_timeout, statement_cache_size=statement_cache_size, @@ -1078,7 +1088,11 @@ async def __connect_addr( proto_factory = lambda: protocol.Protocol( addr, connected, params, record_class, loop) - if isinstance(addr, str): + if params.connector_factory is not None: + connector = params.connector_factory( + proto_factory, *addr, loop=loop, ssl=params.ssl) + + elif isinstance(addr, str): # UNIX socket connector = loop.create_unix_connection(proto_factory, addr) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 71fb04f8..5ba2b418 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -2099,7 +2099,8 @@ async def connect(dsn=None, *, server_settings=None, target_session_attrs=None, krbsrvname=None, - gsslib=None): + gsslib=None, + connector_factory=None): r"""A coroutine to establish a connection to a PostgreSQL server. The connection parameters may be specified either as a connection @@ -2343,6 +2344,42 @@ async def connect(dsn=None, *, GSS library to use for GSSAPI/SSPI authentication. Can be 'gssapi' or 'sspi'. Defaults to 'sspi' on Windows and 'gssapi' otherwise. + :param callable connector_factory: + A callable that provides full control over how the network connection + to the PostgreSQL server is established. When specified, this + factory replaces the default connection logic. The callable receives + the following arguments: + + - ``proto_factory`` - a callable that returns the asyncpg protocol + instance + - ``host`` - the target hostname (positional) + - ``port`` - the target port (positional) + - ``loop`` - the event loop (keyword argument) + - ``ssl`` - the SSL context, or ``None`` (keyword argument) + + The callable must return an awaitable that resolves to a + ``(transport, protocol)`` tuple, compatible with + :meth:`asyncio.loop.create_connection`. + + This is useful for scenarios such as connecting through a proxy, + establishing an SSH tunnel, or performing custom socket setup + before the PostgreSQL protocol begins. + + Example: + + .. code-block:: python + + async def my_connector(proto_factory, host, port, *, loop, ssl): + tunnel_sock = await open_ssh_tunnel(host, port) + return await loop.create_connection( + proto_factory, sock=tunnel_sock, ssl=ssl) + + conn = await asyncpg.connect( + connector_factory=my_connector, + host='db.example.com', + user='postgres', + ) + :return: A :class:`~asyncpg.connection.Connection` instance. Example: @@ -2463,6 +2500,7 @@ async def connect(dsn=None, *, target_session_attrs=target_session_attrs, krbsrvname=krbsrvname, gsslib=gsslib, + connector_factory=connector_factory, ) diff --git a/tests/test_connect.py b/tests/test_connect.py index 955fb825..4950586d 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1207,6 +1207,8 @@ def run_testcase(self, testcase): # Avoid the hassle of specifying gsslib # unless explicitly tested for params.pop('gsslib', None) + if 'connector_factory' not in expected[1]: + params.pop('connector_factory', None) self.assertEqual(expected, result, 'Testcase: {}'.format(testcase)) @@ -1792,6 +1794,67 @@ async def test_connection_no_home_dir(self): user='ssl_user', ssl='verify-full') + async def test_connection_connector_factory(self): + conn_spec = self.get_connection_spec() + host = conn_spec.get('host') + port = conn_spec.get('port') + + factory_called = False + + async def connector_factory(proto_factory, host, port, *, loop, ssl): + nonlocal factory_called + factory_called = True + sock = socket.create_connection((host, port)) + sock.setblocking(False) + return await loop.create_connection( + proto_factory, sock=sock, ssl=ssl) + + con = await asyncpg.connect( + host=host, + port=port, + user=conn_spec.get('user'), + database=conn_spec.get('database'), + ssl=False, + connector_factory=connector_factory, + ) + try: + self.assertTrue(factory_called) + self.assertEqual(await con.fetchval('SELECT 42'), 42) + finally: + await con.close() + + async def test_connection_connector_factory_with_pool(self): + conn_spec = self.get_connection_spec() + host = conn_spec.get('host') + port = conn_spec.get('port') + + factory_called = False + + async def connector_factory(proto_factory, host, port, *, loop, ssl): + nonlocal factory_called + factory_called = True + sock = socket.create_connection((host, port)) + sock.setblocking(False) + return await loop.create_connection( + proto_factory, sock=sock, ssl=ssl) + + pool = await asyncpg.create_pool( + host=host, + port=port, + user=conn_spec.get('user'), + database=conn_spec.get('database'), + ssl=False, + connector_factory=connector_factory, + min_size=1, + max_size=1, + ) + try: + self.assertTrue(factory_called) + async with pool.acquire() as con: + self.assertEqual(await con.fetchval('SELECT 42'), 42) + finally: + await pool.close() + class BaseTestSSLConnection(tb.ConnectedTestCase): @classmethod