From cb588311723e93f493c021c9c78405c96f70d76b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Ercolanelli?= Date: Wed, 26 Nov 2025 18:02:26 +0100 Subject: [PATCH] fix: use a single client for handshake and queries --- src/altertable_flightsql/client.py | 51 +++++++++++------------------- 1 file changed, 18 insertions(+), 33 deletions(-) diff --git a/src/altertable_flightsql/client.py b/src/altertable_flightsql/client.py index 40dc617..9735fe9 100644 --- a/src/altertable_flightsql/client.py +++ b/src/altertable_flightsql/client.py @@ -57,35 +57,32 @@ class IngestIncrementalOptions: class BearerAuthMiddleware(flight.ClientMiddleware): """Client middleware that adds Bearer token authentication to all requests.""" - def __init__(self, token: bytes): - """ - Initialize the middleware with a Bearer token. - - Args: - token: The Bearer token - """ - self._token = token + def __init__(self, factory: "BearerAuthMiddlewareFactory"): + self._factory = factory def sending_headers(self): """A callback before headers are sent.""" - return {b"authorization": self._token} + headers = {} + + if self._factory.token: + headers[b"authorization"] = self._factory.token + + return headers + + def received_headers(self, headers): + if token := headers.get("authorization"): + self._factory.token = token class BearerAuthMiddlewareFactory(flight.ClientMiddlewareFactory): """Factory for creating Bearer authentication middleware.""" - def __init__(self, token: bytes): - """ - Initialize the factory with credentials. - - Args: - token: The Bearer token - """ - self._token = token + def __init__(self): + self.token = None def start_call(self, info): """Create middleware instance for a new call.""" - return BearerAuthMiddleware(self._token) + return BearerAuthMiddleware(self) class Client: @@ -139,9 +136,9 @@ def __init__( self._auto_commit = auto_commit self._transaction = None - token = self._handshake() - auth_middleware = BearerAuthMiddlewareFactory(token) + auth_middleware = BearerAuthMiddlewareFactory() self._client = flight.FlightClient(location, middleware=[auth_middleware]) + self._client.authenticate_basic_token(self._username, self._password) options = {} if catalog: @@ -153,17 +150,6 @@ def __init__( if options: self._set_options(options) - def _handshake(self) -> bytes: - """ - Perform authentication handshake with the server. - - Returns: - bytes: The Bearer token returned by the server - """ - with flight.FlightClient(self._location) as client: - header = client.authenticate_basic_token(self._username, self._password) - return header[1] - def _set_options(self, options: Mapping[str, sql_pb2.SessionOptionValue]): cmd = sql_pb2.SetSessionOptionsRequest(session_options=options) action = flight.Action("SetSessionOptions", _pack_command(cmd)) @@ -688,8 +674,7 @@ def _get_parameter_as_pyarrow( # Create record batch with positional parameters if len(parameters) != len(self._parameter_schema): raise ValueError( - f"Expected {len(self._parameter_schema)} parameters, " - f"but got {len(parameters)}" + f"Expected {len(self._parameter_schema)} parameters, but got {len(parameters)}" ) param_dict = { field.name: [value] for field, value in zip(self._parameter_schema, parameters)