Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 18 additions & 33 deletions src/altertable_flightsql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,35 +57,32 @@ class IngestIncrementalOptions:
class BearerAuthMiddleware(flight.ClientMiddleware):
"""Client middleware that adds Bearer token authentication to all requests."""

def __init__(self, token: bytes):
"""
Initialize the middleware with a Bearer token.

Args:
token: The Bearer token
"""
self._token = token
def __init__(self, factory: "BearerAuthMiddlewareFactory"):
self._factory = factory

def sending_headers(self):
"""A callback before headers are sent."""
return {b"authorization": self._token}
headers = {}

if self._factory.token:
headers[b"authorization"] = self._factory.token

return headers

def received_headers(self, headers):
if token := headers.get("authorization"):
self._factory.token = token


class BearerAuthMiddlewareFactory(flight.ClientMiddlewareFactory):
"""Factory for creating Bearer authentication middleware."""

def __init__(self, token: bytes):
"""
Initialize the factory with credentials.

Args:
token: The Bearer token
"""
self._token = token
def __init__(self):
self.token = None

def start_call(self, info):
"""Create middleware instance for a new call."""
return BearerAuthMiddleware(self._token)
return BearerAuthMiddleware(self)


class Client:
Expand Down Expand Up @@ -139,9 +136,9 @@ def __init__(
self._auto_commit = auto_commit
self._transaction = None

token = self._handshake()
auth_middleware = BearerAuthMiddlewareFactory(token)
auth_middleware = BearerAuthMiddlewareFactory()
self._client = flight.FlightClient(location, middleware=[auth_middleware])
self._client.authenticate_basic_token(self._username, self._password)

options = {}
if catalog:
Expand All @@ -153,17 +150,6 @@ def __init__(
if options:
self._set_options(options)

def _handshake(self) -> bytes:
"""
Perform authentication handshake with the server.

Returns:
bytes: The Bearer token returned by the server
"""
with flight.FlightClient(self._location) as client:
header = client.authenticate_basic_token(self._username, self._password)
return header[1]

def _set_options(self, options: Mapping[str, sql_pb2.SessionOptionValue]):
cmd = sql_pb2.SetSessionOptionsRequest(session_options=options)
action = flight.Action("SetSessionOptions", _pack_command(cmd))
Expand Down Expand Up @@ -688,8 +674,7 @@ def _get_parameter_as_pyarrow(
# Create record batch with positional parameters
if len(parameters) != len(self._parameter_schema):
raise ValueError(
f"Expected {len(self._parameter_schema)} parameters, "
f"but got {len(parameters)}"
f"Expected {len(self._parameter_schema)} parameters, but got {len(parameters)}"
)
param_dict = {
field.name: [value] for field, value in zip(self._parameter_schema, parameters)
Expand Down