diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 07c4fdde..a5d8c1c2 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -66,6 +66,7 @@ class SSLNegotiation(compat.StrEnum): 'target_session_attrs', 'krbsrvname', 'gsslib', + 'connector_factory', ]) @@ -854,7 +855,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, sslmode=sslmode, ssl_negotiation=sslneg, server_settings=server_settings, target_session_attrs=target_session_attrs, - krbsrvname=krbsrvname, gsslib=gsslib) + krbsrvname=krbsrvname, gsslib=gsslib, connector_factory=None) return addrs, params @@ -866,7 +867,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, max_cacheable_statement_size, ssl, direct_tls, server_settings, target_session_attrs, krbsrvname, gsslib, - service, servicefile): + service, servicefile, connector_factory=None): local_vars = locals() for var_name in {'max_cacheable_statement_size', 'max_cached_statement_lifetime', @@ -899,6 +900,15 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, krbsrvname=krbsrvname, gsslib=gsslib, service=service, servicefile=servicefile) + if connector_factory is not None: + if not callable(connector_factory): + raise TypeError( + "connector_factory is expected to be a callable, got {!r}".format( + type(connector_factory) + ) + ) + params = params._replace(connector_factory=connector_factory) + config = _ClientConfiguration( command_timeout=command_timeout, statement_cache_size=statement_cache_size, @@ -1078,7 +1088,11 @@ async def __connect_addr( proto_factory = lambda: protocol.Protocol( addr, connected, params, record_class, loop) - if isinstance(addr, str): + if params.connector_factory is not None: + connector = params.connector_factory( + proto_factory, *addr, loop=loop, ssl=params.ssl) + + elif isinstance(addr, str): # UNIX socket connector = loop.create_unix_connection(proto_factory, addr) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 71fb04f8..5ba2b418 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -2099,7 +2099,8 @@ async def connect(dsn=None, *, server_settings=None, target_session_attrs=None, krbsrvname=None, - gsslib=None): + gsslib=None, + connector_factory=None): r"""A coroutine to establish a connection to a PostgreSQL server. The connection parameters may be specified either as a connection @@ -2343,6 +2344,42 @@ async def connect(dsn=None, *, GSS library to use for GSSAPI/SSPI authentication. Can be 'gssapi' or 'sspi'. Defaults to 'sspi' on Windows and 'gssapi' otherwise. + :param callable connector_factory: + A callable that provides full control over how the network connection + to the PostgreSQL server is established. When specified, this + factory replaces the default connection logic. The callable receives + the following arguments: + + - ``proto_factory`` - a callable that returns the asyncpg protocol + instance + - ``host`` - the target hostname (positional) + - ``port`` - the target port (positional) + - ``loop`` - the event loop (keyword argument) + - ``ssl`` - the SSL context, or ``None`` (keyword argument) + + The callable must return an awaitable that resolves to a + ``(transport, protocol)`` tuple, compatible with + :meth:`asyncio.loop.create_connection`. + + This is useful for scenarios such as connecting through a proxy, + establishing an SSH tunnel, or performing custom socket setup + before the PostgreSQL protocol begins. + + Example: + + .. code-block:: python + + async def my_connector(proto_factory, host, port, *, loop, ssl): + tunnel_sock = await open_ssh_tunnel(host, port) + return await loop.create_connection( + proto_factory, sock=tunnel_sock, ssl=ssl) + + conn = await asyncpg.connect( + connector_factory=my_connector, + host='db.example.com', + user='postgres', + ) + :return: A :class:`~asyncpg.connection.Connection` instance. Example: @@ -2463,6 +2500,7 @@ async def connect(dsn=None, *, target_session_attrs=target_session_attrs, krbsrvname=krbsrvname, gsslib=gsslib, + connector_factory=connector_factory, ) diff --git a/tests/test_connect.py b/tests/test_connect.py index 955fb825..4950586d 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1207,6 +1207,8 @@ def run_testcase(self, testcase): # Avoid the hassle of specifying gsslib # unless explicitly tested for params.pop('gsslib', None) + if 'connector_factory' not in expected[1]: + params.pop('connector_factory', None) self.assertEqual(expected, result, 'Testcase: {}'.format(testcase)) @@ -1792,6 +1794,67 @@ async def test_connection_no_home_dir(self): user='ssl_user', ssl='verify-full') + async def test_connection_connector_factory(self): + conn_spec = self.get_connection_spec() + host = conn_spec.get('host') + port = conn_spec.get('port') + + factory_called = False + + async def connector_factory(proto_factory, host, port, *, loop, ssl): + nonlocal factory_called + factory_called = True + sock = socket.create_connection((host, port)) + sock.setblocking(False) + return await loop.create_connection( + proto_factory, sock=sock, ssl=ssl) + + con = await asyncpg.connect( + host=host, + port=port, + user=conn_spec.get('user'), + database=conn_spec.get('database'), + ssl=False, + connector_factory=connector_factory, + ) + try: + self.assertTrue(factory_called) + self.assertEqual(await con.fetchval('SELECT 42'), 42) + finally: + await con.close() + + async def test_connection_connector_factory_with_pool(self): + conn_spec = self.get_connection_spec() + host = conn_spec.get('host') + port = conn_spec.get('port') + + factory_called = False + + async def connector_factory(proto_factory, host, port, *, loop, ssl): + nonlocal factory_called + factory_called = True + sock = socket.create_connection((host, port)) + sock.setblocking(False) + return await loop.create_connection( + proto_factory, sock=sock, ssl=ssl) + + pool = await asyncpg.create_pool( + host=host, + port=port, + user=conn_spec.get('user'), + database=conn_spec.get('database'), + ssl=False, + connector_factory=connector_factory, + min_size=1, + max_size=1, + ) + try: + self.assertTrue(factory_called) + async with pool.acquire() as con: + self.assertEqual(await con.fetchval('SELECT 42'), 42) + finally: + await pool.close() + class BaseTestSSLConnection(tb.ConnectedTestCase): @classmethod