Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,10 @@ def connect(
route_to_leader_enabled=True,
database_role=None,
experimental_host=None,
use_plain_text=False,
ca_certificate=None,
client_certificate=None,
client_key=None,
**kwargs,
):
"""Creates a connection to a Google Cloud Spanner database.
Expand Down Expand Up @@ -789,6 +793,28 @@ def connect(
:rtype: :class:`google.cloud.spanner_dbapi.connection.Connection`
:returns: Connection object associated with the given Google Cloud Spanner
resource.

:type experimental_host: str
:param experimental_host: (Optional) The endpoint for a spanner experimental host deployment.
This is intended only for experimental host spanner endpoints.

:type use_plain_text: bool
:param use_plain_text: (Optional) Whether to use plain text for the connection.
This is intended only for experimental host spanner endpoints.
If not set, the default behavior is to use TLS.

:type ca_certificate: str
:param ca_certificate: (Optional) The path to the CA certificate file used for TLS connection.
This is intended only for experimental host spanner endpoints.
This is mandatory if the experimental_host requires a TLS connection.
:type client_certificate: str
:param client_certificate: (Optional) The path to the client certificate file used for mTLS connection.
This is intended only for experimental host spanner endpoints.
This is mandatory if the experimental_host requires an mTLS connection.
:type client_key: str
:param client_key: (Optional) The path to the client key file used for mTLS connection.
This is intended only for experimental host spanner endpoints.
This is mandatory if the experimental_host requires an mTLS connection.
"""
if client is None:
client_info = ClientInfo(
Expand Down Expand Up @@ -817,6 +843,10 @@ def connect(
client_info=client_info,
route_to_leader_enabled=route_to_leader_enabled,
client_options=client_options,
use_plain_text=use_plain_text,
ca_certificate=ca_certificate,
client_certificate=client_certificate,
client_key=client_key,
)
else:
if project is not None and client.project != project:
Expand Down
62 changes: 62 additions & 0 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,3 +803,65 @@ def _merge_Transaction_Options(

# Convert protobuf object back into a TransactionOptions instance
return TransactionOptions(merged_pb)


def _create_experimental_host_transport(
transport_factory,
experimental_host,
use_plain_text,
ca_certificate,
client_certificate,
client_key,
interceptors=None,
):
"""Creates an experimental host transport for Spanner.

Args:
transport_factory (type): The transport class to instantiate (e.g.
`SpannerGrpcTransport`).
experimental_host (str): The endpoint for the experimental host.
use_plain_text (bool): Whether to use a plain text (insecure) connection.
ca_certificate (str): Path to the CA certificate file for TLS.
client_certificate (str): Path to the client certificate file for mTLS.
client_key (str): Path to the client key file for mTLS.
interceptors (list): Optional list of interceptors to add to the channel.

Returns:
object: An instance of the transport class created by `transport_factory`.

Raises:
ValueError: If TLS/mTLS configuration is invalid.
"""
import grpc
from google.auth.credentials import AnonymousCredentials

channel = None
if use_plain_text:
channel = grpc.insecure_channel(target=experimental_host)
elif ca_certificate:
with open(ca_certificate, "rb") as f:
ca_cert = f.read()
if client_certificate and client_key:
with open(client_certificate, "rb") as f:
client_cert = f.read()
with open(client_key, "rb") as f:
private_key = f.read()
ssl_creds = grpc.ssl_channel_credentials(
root_certificates=ca_cert,
private_key=private_key,
certificate_chain=client_cert,
)
elif client_certificate or client_key:
raise ValueError(
"Both client_certificate and client_key must be provided for mTLS connection"
)
else:
ssl_creds = grpc.ssl_channel_credentials(root_certificates=ca_cert)
channel = grpc.secure_channel(experimental_host, ssl_creds)
else:
raise ValueError(
"TLS/mTLS connection requires ca_certificate to be set for experimental_host"
)
if interceptors is not None:
channel = grpc.intercept_channel(channel, *interceptors)
return transport_factory(channel=channel, credentials=AnonymousCredentials())
60 changes: 53 additions & 7 deletions google/cloud/spanner_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@
from google.cloud.spanner_v1 import __version__
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import DefaultTransactionOptions
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import (
_create_experimental_host_transport,
_merge_query_options,
)
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1.instance import Instance
from google.cloud.spanner_v1.metrics.constants import (
Expand Down Expand Up @@ -186,6 +189,30 @@ class Client(ClientWithProject):

:raises: :class:`ValueError <exceptions.ValueError>` if both ``read_only``
and ``admin`` are :data:`True`

:type use_plain_text: bool
:param use_plain_text: (Optional) Whether to use plain text for the connection.
This is intended only for experimental host spanner endpoints.
If set, this will override the `api_endpoint` in `client_options`.
If not set, the default behavior is to use TLS.

:type ca_certificate: str
:param ca_certificate: (Optional) The path to the CA certificate file used for TLS connection.
This is intended only for experimental host spanner endpoints.
If set, this will override the `api_endpoint` in `client_options`.
This is mandatory if the experimental_host requires a TLS connection.

:type client_certificate: str
:param client_certificate: (Optional) The path to the client certificate file used for mTLS connection.
This is intended only for experimental host spanner endpoints.
If set, this will override the `api_endpoint` in `client_options`.
This is mandatory if the experimental_host requires a mTLS connection.

:type client_key: str
:param client_key: (Optional) The path to the client key file used for mTLS connection.
This is intended only for experimental host spanner endpoints.
If set, this will override the `api_endpoint` in `client_options`.
This is mandatory if the experimental_host requires a mTLS connection.
"""

_instance_admin_api = None
Expand All @@ -210,6 +237,10 @@ def __init__(
default_transaction_options: Optional[DefaultTransactionOptions] = None,
experimental_host=None,
disable_builtin_metrics=False,
use_plain_text=False,
ca_certificate=None,
client_certificate=None,
client_key=None,
):
self._emulator_host = _get_spanner_emulator_host()
self._experimental_host = experimental_host
Expand All @@ -224,6 +255,12 @@ def __init__(
if self._emulator_host:
credentials = AnonymousCredentials()
elif self._experimental_host:
# For all experimental host endpoints project is default
project = "default"
self._use_plain_text = use_plain_text
self._ca_certificate = ca_certificate
self._client_certificate = client_certificate
self._client_key = client_key
credentials = AnonymousCredentials()
elif isinstance(credentials, AnonymousCredentials):
self._emulator_host = self._client_options.api_endpoint
Expand Down Expand Up @@ -259,7 +296,7 @@ def __init__(
):
meter_provider = metrics.NoOpMeterProvider()
try:
if not _get_spanner_emulator_host():
if not _get_spanner_emulator_host() and not self._experimental_host:
meter_provider = MeterProvider(
metric_readers=[
PeriodicExportingMetricReader(
Expand Down Expand Up @@ -339,8 +376,13 @@ def instance_admin_api(self):
transport=transport,
)
elif self._experimental_host:
transport = InstanceAdminGrpcTransport(
channel=grpc.insecure_channel(target=self._experimental_host)
transport = _create_experimental_host_transport(
InstanceAdminGrpcTransport,
self._experimental_host,
self._use_plain_text,
self._ca_certificate,
self._client_certificate,
self._client_key,
)
self._instance_admin_api = InstanceAdminClient(
client_info=self._client_info,
Expand Down Expand Up @@ -369,8 +411,13 @@ def database_admin_api(self):
transport=transport,
)
elif self._experimental_host:
transport = DatabaseAdminGrpcTransport(
channel=grpc.insecure_channel(target=self._experimental_host)
transport = _create_experimental_host_transport(
DatabaseAdminGrpcTransport,
self._experimental_host,
self._use_plain_text,
self._ca_certificate,
self._client_certificate,
self._client_key,
)
self._database_admin_api = DatabaseAdminClient(
client_info=self._client_info,
Expand Down Expand Up @@ -517,7 +564,6 @@ def instance(
self._emulator_host,
labels,
processing_units,
self._experimental_host,
)

def list_instances(self, filter_="", page_size=None):
Expand Down
18 changes: 11 additions & 7 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
_metadata_with_request_id,
_create_experimental_host_transport,
)
from google.cloud.spanner_v1.batch import Batch
from google.cloud.spanner_v1.batch import MutationGroups
Expand Down Expand Up @@ -203,11 +204,9 @@ def __init__(

self._pool = pool
pool.bind(self)
is_experimental_host = self._instance.experimental_host is not None
self._experimental_host = self._instance._client._experimental_host

self._sessions_manager = DatabaseSessionsManager(
self, pool, is_experimental_host
)
self._sessions_manager = DatabaseSessionsManager(self, pool)

@classmethod
def from_pb(cls, database_pb, instance, pool=None):
Expand Down Expand Up @@ -452,9 +451,14 @@ def spanner_api(self):
client_info=client_info, transport=transport
)
return self._spanner_api
if self._instance.experimental_host is not None:
transport = SpannerGrpcTransport(
channel=grpc.insecure_channel(self._instance.experimental_host)
if self._experimental_host is not None:
transport = _create_experimental_host_transport(
SpannerGrpcTransport,
self._experimental_host,
self._instance._client._use_plain_text,
self._instance._client._ca_certificate,
self._instance._client._client_certificate,
self._instance._client._client_key,
)
self._spanner_api = SpannerClient(
client_info=client_info,
Expand Down
6 changes: 3 additions & 3 deletions google/cloud/spanner_v1/database_sessions_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,9 @@ class DatabaseSessionsManager(object):
_MAINTENANCE_THREAD_POLLING_INTERVAL = timedelta(minutes=10)
_MAINTENANCE_THREAD_REFRESH_INTERVAL = timedelta(days=7)

def __init__(self, database, pool, is_experimental_host: bool = False):
def __init__(self, database, pool):
self._database = database
self._pool = pool
self._is_experimental_host = is_experimental_host

# Declare multiplexed session attributes. When a multiplexed session for the
# database session manager is created, a maintenance thread is initialized to
Expand All @@ -89,7 +88,8 @@ def get_session(self, transaction_type: TransactionType) -> Session:

session = (
self._get_multiplexed_session()
if self._use_multiplexed(transaction_type) or self._is_experimental_host
if self._use_multiplexed(transaction_type)
or self._database._experimental_host is not None
else self._pool.get()
)

Expand Down
2 changes: 0 additions & 2 deletions google/cloud/spanner_v1/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ def __init__(
emulator_host=None,
labels=None,
processing_units=None,
experimental_host=None,
):
self.instance_id = instance_id
self._client = client
Expand All @@ -143,7 +142,6 @@ def __init__(
self._node_count = processing_units // PROCESSING_UNITS_PER_NODE
self.display_name = display_name or instance_id
self.emulator_host = emulator_host
self.experimental_host = experimental_host
if labels is None:
labels = {}
self.labels = labels
Expand Down
15 changes: 11 additions & 4 deletions google/cloud/spanner_v1/testing/database_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import google.auth.credentials
from google.cloud.spanner_admin_database_v1 import DatabaseDialect
from google.cloud.spanner_v1 import SpannerClient
from google.cloud.spanner_v1._helpers import _create_experimental_host_transport
from google.cloud.spanner_v1.database import Database, SPANNER_DATA_SCOPE
from google.cloud.spanner_v1.services.spanner.transports import (
SpannerGrpcTransport,
Expand Down Expand Up @@ -86,12 +87,18 @@ def spanner_api(self):
transport=transport,
)
return self._spanner_api
if self._instance.experimental_host is not None:
channel = grpc.insecure_channel(self._instance.experimental_host)
if self._experimental_host is not None:
self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor()
self._interceptors.append(self._x_goog_request_id_interceptor)
channel = grpc.intercept_channel(channel, *self._interceptors)
transport = SpannerGrpcTransport(channel=channel)
transport = _create_experimental_host_transport(
SpannerGrpcTransport,
self._experimental_host,
self._instance._client._use_plain_text,
self._instance._client._ca_certificate,
self._instance._client._client_certificate,
self._instance._client._client_key,
self._interceptors,
)
self._spanner_api = SpannerClient(
client_info=client_info,
transport=transport,
Expand Down
9 changes: 8 additions & 1 deletion tests/system/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,14 @@
EXPERIMENTAL_HOST = os.getenv(USE_EXPERIMENTAL_HOST_ENVVAR)
USE_EXPERIMENTAL_HOST = EXPERIMENTAL_HOST is not None

EXPERIMENTAL_HOST_PROJECT = "default"
CA_CERTIFICATE_ENVVAR = "CA_CERTIFICATE"
CA_CERTIFICATE = os.getenv(CA_CERTIFICATE_ENVVAR)
CLIENT_CERTIFICATE_ENVVAR = "CLIENT_CERTIFICATE"
CLIENT_CERTIFICATE = os.getenv(CLIENT_CERTIFICATE_ENVVAR)
CLIENT_KEY_ENVVAR = "CLIENT_KEY"
CLIENT_KEY = os.getenv(CLIENT_KEY_ENVVAR)
USE_PLAIN_TEXT = CA_CERTIFICATE is None

EXPERIMENTAL_HOST_INSTANCE = "default"

DDL_STATEMENTS = (
Expand Down
5 changes: 4 additions & 1 deletion tests/system/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,10 @@ def spanner_client():

credentials = AnonymousCredentials()
return spanner_v1.Client(
project=_helpers.EXPERIMENTAL_HOST_PROJECT,
use_plain_text=_helpers.USE_PLAIN_TEXT,
ca_certificate=_helpers.CA_CERTIFICATE,
client_certificate=_helpers.CLIENT_CERTIFICATE,
client_key=_helpers.CLIENT_KEY,
credentials=credentials,
experimental_host=_helpers.EXPERIMENTAL_HOST,
)
Expand Down
4 changes: 4 additions & 0 deletions tests/system/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1442,6 +1442,10 @@ def test_user_agent(self, shared_instance, dbapi_database):
experimental_host=_helpers.EXPERIMENTAL_HOST
if _helpers.USE_EXPERIMENTAL_HOST
else None,
use_plain_text=_helpers.USE_PLAIN_TEXT,
ca_certificate=_helpers.CA_CERTIFICATE,
client_certificate=_helpers.CLIENT_CERTIFICATE,
client_key=_helpers.CLIENT_KEY,
)
assert (
conn.instance._client._client_info.user_agent
Expand Down
Loading
Loading