Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class SSLNegotiation(compat.StrEnum):
'target_session_attrs',
'krbsrvname',
'gsslib',
'connector_factory',
])


Expand Down Expand Up @@ -854,7 +855,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
sslmode=sslmode, ssl_negotiation=sslneg,
server_settings=server_settings,
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname, gsslib=gsslib)
krbsrvname=krbsrvname, gsslib=gsslib, connector_factory=None)

return addrs, params

Expand All @@ -866,7 +867,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
max_cacheable_statement_size,
ssl, direct_tls, server_settings,
target_session_attrs, krbsrvname, gsslib,
service, servicefile):
service, servicefile, connector_factory=None):
local_vars = locals()
for var_name in {'max_cacheable_statement_size',
'max_cached_statement_lifetime',
Expand Down Expand Up @@ -899,6 +900,15 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
krbsrvname=krbsrvname, gsslib=gsslib,
service=service, servicefile=servicefile)

if connector_factory is not None:
if not callable(connector_factory):
raise TypeError(
"connector_factory is expected to be a callable, got {!r}".format(
type(connector_factory)
)
)
params = params._replace(connector_factory=connector_factory)

config = _ClientConfiguration(
command_timeout=command_timeout,
statement_cache_size=statement_cache_size,
Expand Down Expand Up @@ -1078,7 +1088,11 @@ async def __connect_addr(
proto_factory = lambda: protocol.Protocol(
addr, connected, params, record_class, loop)

if isinstance(addr, str):
if params.connector_factory is not None:
connector = params.connector_factory(
proto_factory, *addr, loop=loop, ssl=params.ssl)

elif isinstance(addr, str):
# UNIX socket
connector = loop.create_unix_connection(proto_factory, addr)

Expand Down
40 changes: 39 additions & 1 deletion asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2099,7 +2099,8 @@ async def connect(dsn=None, *,
server_settings=None,
target_session_attrs=None,
krbsrvname=None,
gsslib=None):
gsslib=None,
connector_factory=None):
r"""A coroutine to establish a connection to a PostgreSQL server.

The connection parameters may be specified either as a connection
Expand Down Expand Up @@ -2343,6 +2344,42 @@ async def connect(dsn=None, *,
GSS library to use for GSSAPI/SSPI authentication. Can be 'gssapi'
or 'sspi'. Defaults to 'sspi' on Windows and 'gssapi' otherwise.

:param callable connector_factory:
A callable that provides full control over how the network connection
to the PostgreSQL server is established. When specified, this
factory replaces the default connection logic. The callable receives
the following arguments:

- ``proto_factory`` - a callable that returns the asyncpg protocol
instance
- ``host`` - the target hostname (positional)
- ``port`` - the target port (positional)
- ``loop`` - the event loop (keyword argument)
- ``ssl`` - the SSL context, or ``None`` (keyword argument)

The callable must return an awaitable that resolves to a
``(transport, protocol)`` tuple, compatible with
:meth:`asyncio.loop.create_connection`.

This is useful for scenarios such as connecting through a proxy,
establishing an SSH tunnel, or performing custom socket setup
before the PostgreSQL protocol begins.

Example:

.. code-block:: python

async def my_connector(proto_factory, host, port, *, loop, ssl):
tunnel_sock = await open_ssh_tunnel(host, port)
return await loop.create_connection(
proto_factory, sock=tunnel_sock, ssl=ssl)

conn = await asyncpg.connect(
connector_factory=my_connector,
host='db.example.com',
user='postgres',
)

:return: A :class:`~asyncpg.connection.Connection` instance.

Example:
Expand Down Expand Up @@ -2463,6 +2500,7 @@ async def connect(dsn=None, *,
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname,
gsslib=gsslib,
connector_factory=connector_factory,
)


Expand Down
63 changes: 63 additions & 0 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,8 @@ def run_testcase(self, testcase):
# Avoid the hassle of specifying gsslib
# unless explicitly tested for
params.pop('gsslib', None)
if 'connector_factory' not in expected[1]:
params.pop('connector_factory', None)

self.assertEqual(expected, result, 'Testcase: {}'.format(testcase))

Expand Down Expand Up @@ -1792,6 +1794,67 @@ async def test_connection_no_home_dir(self):
user='ssl_user',
ssl='verify-full')

async def test_connection_connector_factory(self):
conn_spec = self.get_connection_spec()
host = conn_spec.get('host')
port = conn_spec.get('port')

factory_called = False

async def connector_factory(proto_factory, host, port, *, loop, ssl):
nonlocal factory_called
factory_called = True
sock = socket.create_connection((host, port))
sock.setblocking(False)
return await loop.create_connection(
proto_factory, sock=sock, ssl=ssl)

con = await asyncpg.connect(
host=host,
port=port,
user=conn_spec.get('user'),
database=conn_spec.get('database'),
ssl=False,
connector_factory=connector_factory,
)
try:
self.assertTrue(factory_called)
self.assertEqual(await con.fetchval('SELECT 42'), 42)
finally:
await con.close()

async def test_connection_connector_factory_with_pool(self):
conn_spec = self.get_connection_spec()
host = conn_spec.get('host')
port = conn_spec.get('port')

factory_called = False

async def connector_factory(proto_factory, host, port, *, loop, ssl):
nonlocal factory_called
factory_called = True
sock = socket.create_connection((host, port))
sock.setblocking(False)
return await loop.create_connection(
proto_factory, sock=sock, ssl=ssl)

pool = await asyncpg.create_pool(
host=host,
port=port,
user=conn_spec.get('user'),
database=conn_spec.get('database'),
ssl=False,
connector_factory=connector_factory,
min_size=1,
max_size=1,
)
try:
self.assertTrue(factory_called)
async with pool.acquire() as con:
self.assertEqual(await con.fetchval('SELECT 42'), 42)
finally:
await pool.close()


class BaseTestSSLConnection(tb.ConnectedTestCase):
@classmethod
Expand Down