diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 111bc4cc1b..871eb152da 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -736,6 +736,10 @@ def connect( route_to_leader_enabled=True, database_role=None, experimental_host=None, + use_plain_text=False, + ca_certificate=None, + client_certificate=None, + client_key=None, **kwargs, ): """Creates a connection to a Google Cloud Spanner database. @@ -789,6 +793,28 @@ def connect( :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` :returns: Connection object associated with the given Google Cloud Spanner resource. + + :type experimental_host: str + :param experimental_host: (Optional) The endpoint for a spanner experimental host deployment. + This is intended only for experimental host spanner endpoints. + + :type use_plain_text: bool + :param use_plain_text: (Optional) Whether to use plain text for the connection. + This is intended only for experimental host spanner endpoints. + If not set, the default behavior is to use TLS. + + :type ca_certificate: str + :param ca_certificate: (Optional) The path to the CA certificate file used for TLS connection. + This is intended only for experimental host spanner endpoints. + This is mandatory if the experimental_host requires a TLS connection. + :type client_certificate: str + :param client_certificate: (Optional) The path to the client certificate file used for mTLS connection. + This is intended only for experimental host spanner endpoints. + This is mandatory if the experimental_host requires an mTLS connection. + :type client_key: str + :param client_key: (Optional) The path to the client key file used for mTLS connection. + This is intended only for experimental host spanner endpoints. + This is mandatory if the experimental_host requires an mTLS connection. """ if client is None: client_info = ClientInfo( @@ -817,6 +843,10 @@ def connect( client_info=client_info, route_to_leader_enabled=route_to_leader_enabled, client_options=client_options, + use_plain_text=use_plain_text, + ca_certificate=ca_certificate, + client_certificate=client_certificate, + client_key=client_key, ) else: if project is not None and client.project != project: diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 8a200fe812..826102cde0 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -803,3 +803,65 @@ def _merge_Transaction_Options( # Convert protobuf object back into a TransactionOptions instance return TransactionOptions(merged_pb) + + +def _create_experimental_host_transport( + transport_factory, + experimental_host, + use_plain_text, + ca_certificate, + client_certificate, + client_key, + interceptors=None, +): + """Creates an experimental host transport for Spanner. + + Args: + transport_factory (type): The transport class to instantiate (e.g. + `SpannerGrpcTransport`). + experimental_host (str): The endpoint for the experimental host. + use_plain_text (bool): Whether to use a plain text (insecure) connection. + ca_certificate (str): Path to the CA certificate file for TLS. + client_certificate (str): Path to the client certificate file for mTLS. + client_key (str): Path to the client key file for mTLS. + interceptors (list): Optional list of interceptors to add to the channel. + + Returns: + object: An instance of the transport class created by `transport_factory`. + + Raises: + ValueError: If TLS/mTLS configuration is invalid. + """ + import grpc + from google.auth.credentials import AnonymousCredentials + + channel = None + if use_plain_text: + channel = grpc.insecure_channel(target=experimental_host) + elif ca_certificate: + with open(ca_certificate, "rb") as f: + ca_cert = f.read() + if client_certificate and client_key: + with open(client_certificate, "rb") as f: + client_cert = f.read() + with open(client_key, "rb") as f: + private_key = f.read() + ssl_creds = grpc.ssl_channel_credentials( + root_certificates=ca_cert, + private_key=private_key, + certificate_chain=client_cert, + ) + elif client_certificate or client_key: + raise ValueError( + "Both client_certificate and client_key must be provided for mTLS connection" + ) + else: + ssl_creds = grpc.ssl_channel_credentials(root_certificates=ca_cert) + channel = grpc.secure_channel(experimental_host, ssl_creds) + else: + raise ValueError( + "TLS/mTLS connection requires ca_certificate to be set for experimental_host" + ) + if interceptors is not None: + channel = grpc.intercept_channel(channel, *interceptors) + return transport_factory(channel=channel, credentials=AnonymousCredentials()) diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index 5f72905616..11f4d7834e 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -48,7 +48,10 @@ from google.cloud.spanner_v1 import __version__ from google.cloud.spanner_v1 import ExecuteSqlRequest from google.cloud.spanner_v1 import DefaultTransactionOptions -from google.cloud.spanner_v1._helpers import _merge_query_options +from google.cloud.spanner_v1._helpers import ( + _create_experimental_host_transport, + _merge_query_options, +) from google.cloud.spanner_v1._helpers import _metadata_with_prefix from google.cloud.spanner_v1.instance import Instance from google.cloud.spanner_v1.metrics.constants import ( @@ -186,6 +189,30 @@ class Client(ClientWithProject): :raises: :class:`ValueError ` if both ``read_only`` and ``admin`` are :data:`True` + + :type use_plain_text: bool + :param use_plain_text: (Optional) Whether to use plain text for the connection. + This is intended only for experimental host spanner endpoints. + If set, this will override the `api_endpoint` in `client_options`. + If not set, the default behavior is to use TLS. + + :type ca_certificate: str + :param ca_certificate: (Optional) The path to the CA certificate file used for TLS connection. + This is intended only for experimental host spanner endpoints. + If set, this will override the `api_endpoint` in `client_options`. + This is mandatory if the experimental_host requires a TLS connection. + + :type client_certificate: str + :param client_certificate: (Optional) The path to the client certificate file used for mTLS connection. + This is intended only for experimental host spanner endpoints. + If set, this will override the `api_endpoint` in `client_options`. + This is mandatory if the experimental_host requires a mTLS connection. + + :type client_key: str + :param client_key: (Optional) The path to the client key file used for mTLS connection. + This is intended only for experimental host spanner endpoints. + If set, this will override the `api_endpoint` in `client_options`. + This is mandatory if the experimental_host requires a mTLS connection. """ _instance_admin_api = None @@ -210,6 +237,10 @@ def __init__( default_transaction_options: Optional[DefaultTransactionOptions] = None, experimental_host=None, disable_builtin_metrics=False, + use_plain_text=False, + ca_certificate=None, + client_certificate=None, + client_key=None, ): self._emulator_host = _get_spanner_emulator_host() self._experimental_host = experimental_host @@ -224,6 +255,12 @@ def __init__( if self._emulator_host: credentials = AnonymousCredentials() elif self._experimental_host: + # For all experimental host endpoints project is default + project = "default" + self._use_plain_text = use_plain_text + self._ca_certificate = ca_certificate + self._client_certificate = client_certificate + self._client_key = client_key credentials = AnonymousCredentials() elif isinstance(credentials, AnonymousCredentials): self._emulator_host = self._client_options.api_endpoint @@ -259,7 +296,7 @@ def __init__( ): meter_provider = metrics.NoOpMeterProvider() try: - if not _get_spanner_emulator_host(): + if not _get_spanner_emulator_host() and not self._experimental_host: meter_provider = MeterProvider( metric_readers=[ PeriodicExportingMetricReader( @@ -339,8 +376,13 @@ def instance_admin_api(self): transport=transport, ) elif self._experimental_host: - transport = InstanceAdminGrpcTransport( - channel=grpc.insecure_channel(target=self._experimental_host) + transport = _create_experimental_host_transport( + InstanceAdminGrpcTransport, + self._experimental_host, + self._use_plain_text, + self._ca_certificate, + self._client_certificate, + self._client_key, ) self._instance_admin_api = InstanceAdminClient( client_info=self._client_info, @@ -369,8 +411,13 @@ def database_admin_api(self): transport=transport, ) elif self._experimental_host: - transport = DatabaseAdminGrpcTransport( - channel=grpc.insecure_channel(target=self._experimental_host) + transport = _create_experimental_host_transport( + DatabaseAdminGrpcTransport, + self._experimental_host, + self._use_plain_text, + self._ca_certificate, + self._client_certificate, + self._client_key, ) self._database_admin_api = DatabaseAdminClient( client_info=self._client_info, @@ -517,7 +564,6 @@ def instance( self._emulator_host, labels, processing_units, - self._experimental_host, ) def list_instances(self, filter_="", page_size=None): diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 33c442602c..83c2ae8689 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -55,6 +55,7 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, _metadata_with_request_id, + _create_experimental_host_transport, ) from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.batch import MutationGroups @@ -203,11 +204,9 @@ def __init__( self._pool = pool pool.bind(self) - is_experimental_host = self._instance.experimental_host is not None + self._experimental_host = self._instance._client._experimental_host - self._sessions_manager = DatabaseSessionsManager( - self, pool, is_experimental_host - ) + self._sessions_manager = DatabaseSessionsManager(self, pool) @classmethod def from_pb(cls, database_pb, instance, pool=None): @@ -452,9 +451,14 @@ def spanner_api(self): client_info=client_info, transport=transport ) return self._spanner_api - if self._instance.experimental_host is not None: - transport = SpannerGrpcTransport( - channel=grpc.insecure_channel(self._instance.experimental_host) + if self._experimental_host is not None: + transport = _create_experimental_host_transport( + SpannerGrpcTransport, + self._experimental_host, + self._instance._client._use_plain_text, + self._instance._client._ca_certificate, + self._instance._client._client_certificate, + self._instance._client._client_key, ) self._spanner_api = SpannerClient( client_info=client_info, diff --git a/google/cloud/spanner_v1/database_sessions_manager.py b/google/cloud/spanner_v1/database_sessions_manager.py index bc0db1577c..5414a64e13 100644 --- a/google/cloud/spanner_v1/database_sessions_manager.py +++ b/google/cloud/spanner_v1/database_sessions_manager.py @@ -62,10 +62,9 @@ class DatabaseSessionsManager(object): _MAINTENANCE_THREAD_POLLING_INTERVAL = timedelta(minutes=10) _MAINTENANCE_THREAD_REFRESH_INTERVAL = timedelta(days=7) - def __init__(self, database, pool, is_experimental_host: bool = False): + def __init__(self, database, pool): self._database = database self._pool = pool - self._is_experimental_host = is_experimental_host # Declare multiplexed session attributes. When a multiplexed session for the # database session manager is created, a maintenance thread is initialized to @@ -89,7 +88,8 @@ def get_session(self, transaction_type: TransactionType) -> Session: session = ( self._get_multiplexed_session() - if self._use_multiplexed(transaction_type) or self._is_experimental_host + if self._use_multiplexed(transaction_type) + or self._database._experimental_host is not None else self._pool.get() ) diff --git a/google/cloud/spanner_v1/instance.py b/google/cloud/spanner_v1/instance.py index 0d05699728..a67e0e630b 100644 --- a/google/cloud/spanner_v1/instance.py +++ b/google/cloud/spanner_v1/instance.py @@ -122,7 +122,6 @@ def __init__( emulator_host=None, labels=None, processing_units=None, - experimental_host=None, ): self.instance_id = instance_id self._client = client @@ -143,7 +142,6 @@ def __init__( self._node_count = processing_units // PROCESSING_UNITS_PER_NODE self.display_name = display_name or instance_id self.emulator_host = emulator_host - self.experimental_host = experimental_host if labels is None: labels = {} self.labels = labels diff --git a/google/cloud/spanner_v1/testing/database_test.py b/google/cloud/spanner_v1/testing/database_test.py index f3f71d6e85..70a4d6bac2 100644 --- a/google/cloud/spanner_v1/testing/database_test.py +++ b/google/cloud/spanner_v1/testing/database_test.py @@ -17,6 +17,7 @@ import google.auth.credentials from google.cloud.spanner_admin_database_v1 import DatabaseDialect from google.cloud.spanner_v1 import SpannerClient +from google.cloud.spanner_v1._helpers import _create_experimental_host_transport from google.cloud.spanner_v1.database import Database, SPANNER_DATA_SCOPE from google.cloud.spanner_v1.services.spanner.transports import ( SpannerGrpcTransport, @@ -86,12 +87,18 @@ def spanner_api(self): transport=transport, ) return self._spanner_api - if self._instance.experimental_host is not None: - channel = grpc.insecure_channel(self._instance.experimental_host) + if self._experimental_host is not None: self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor() self._interceptors.append(self._x_goog_request_id_interceptor) - channel = grpc.intercept_channel(channel, *self._interceptors) - transport = SpannerGrpcTransport(channel=channel) + transport = _create_experimental_host_transport( + SpannerGrpcTransport, + self._experimental_host, + self._instance._client._use_plain_text, + self._instance._client._ca_certificate, + self._instance._client._client_certificate, + self._instance._client._client_key, + self._interceptors, + ) self._spanner_api = SpannerClient( client_info=client_info, transport=transport, diff --git a/tests/system/_helpers.py b/tests/system/_helpers.py index 10f970427e..90b06aadd7 100644 --- a/tests/system/_helpers.py +++ b/tests/system/_helpers.py @@ -60,7 +60,14 @@ EXPERIMENTAL_HOST = os.getenv(USE_EXPERIMENTAL_HOST_ENVVAR) USE_EXPERIMENTAL_HOST = EXPERIMENTAL_HOST is not None -EXPERIMENTAL_HOST_PROJECT = "default" +CA_CERTIFICATE_ENVVAR = "CA_CERTIFICATE" +CA_CERTIFICATE = os.getenv(CA_CERTIFICATE_ENVVAR) +CLIENT_CERTIFICATE_ENVVAR = "CLIENT_CERTIFICATE" +CLIENT_CERTIFICATE = os.getenv(CLIENT_CERTIFICATE_ENVVAR) +CLIENT_KEY_ENVVAR = "CLIENT_KEY" +CLIENT_KEY = os.getenv(CLIENT_KEY_ENVVAR) +USE_PLAIN_TEXT = CA_CERTIFICATE is None + EXPERIMENTAL_HOST_INSTANCE = "default" DDL_STATEMENTS = ( diff --git a/tests/system/conftest.py b/tests/system/conftest.py index 6b0ad6cebe..00e715767f 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -115,7 +115,10 @@ def spanner_client(): credentials = AnonymousCredentials() return spanner_v1.Client( - project=_helpers.EXPERIMENTAL_HOST_PROJECT, + use_plain_text=_helpers.USE_PLAIN_TEXT, + ca_certificate=_helpers.CA_CERTIFICATE, + client_certificate=_helpers.CLIENT_CERTIFICATE, + client_key=_helpers.CLIENT_KEY, credentials=credentials, experimental_host=_helpers.EXPERIMENTAL_HOST, ) diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index 309f533170..39420f2e2d 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -1442,6 +1442,10 @@ def test_user_agent(self, shared_instance, dbapi_database): experimental_host=_helpers.EXPERIMENTAL_HOST if _helpers.USE_EXPERIMENTAL_HOST else None, + use_plain_text=_helpers.USE_PLAIN_TEXT, + ca_certificate=_helpers.CA_CERTIFICATE, + client_certificate=_helpers.CLIENT_CERTIFICATE, + client_key=_helpers.CLIENT_KEY, ) assert ( conn.instance._client._client_info.user_agent diff --git a/tests/unit/spanner_dbapi/test_connect.py b/tests/unit/spanner_dbapi/test_connect.py index 5fd2b74a17..2e0c19fc8c 100644 --- a/tests/unit/spanner_dbapi/test_connect.py +++ b/tests/unit/spanner_dbapi/test_connect.py @@ -55,6 +55,10 @@ def test_w_implicit(self, mock_client): client_info=mock.ANY, client_options=mock.ANY, route_to_leader_enabled=True, + use_plain_text=False, + ca_certificate=None, + client_certificate=None, + client_key=None, ) self.assertIs(connection.database, database) @@ -97,6 +101,10 @@ def test_w_explicit(self, mock_client): client_info=mock.ANY, client_options=mock.ANY, route_to_leader_enabled=False, + use_plain_text=False, + ca_certificate=None, + client_certificate=None, + client_key=None, ) client_info = mock_client.call_args_list[0][1]["client_info"] self.assertEqual(client_info.user_agent, USER_AGENT) diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 92001fb52c..c4e23ecda4 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -3543,6 +3543,8 @@ def __init__( self.credentials.expiry = None self.credentials.valid = True + self._experimental_host = None + # Mock the spanner API to return proper session names self._spanner_api = mock.Mock() @@ -3560,14 +3562,11 @@ def _next_nth_request(self): class _Instance(object): - def __init__( - self, name, client=_Client(), emulator_host=None, experimental_host=None - ): + def __init__(self, name, client=_Client(), emulator_host=None): self.name = name self.instance_id = name.rsplit("/", 1)[1] self._client = client self.emulator_host = emulator_host - self.experimental_host = experimental_host class _Backup(object): diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index f3bf6726c0..9d562a6416 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -1023,6 +1023,7 @@ def __init__(self, project, timeout_seconds=None): self.route_to_leader_enabled = True self.directed_read_options = None self.default_transaction_options = DefaultTransactionOptions() + self._experimental_host = None def copy(self): from copy import deepcopy