From 6d9df4b5351d7a7f0405e4615ade052ed94b9f25 Mon Sep 17 00:00:00 2001 From: Pathways-on-Cloud Team Date: Wed, 21 Jan 2026 15:21:46 -0800 Subject: [PATCH] Ensure that the proxy client exits cleanly PiperOrigin-RevId: 859286590 --- .../shared_pathways_service/isc_pathways.py | 29 ++++++++++++++++++- .../run_connect_example.py | 19 ++++++++++++ .../shared_pathways_service/validators.py | 16 ++++++++++ .../yamls/pw-proxy.yaml | 2 +- 4 files changed, 64 insertions(+), 2 deletions(-) diff --git a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py index 9544e00..fe2ae06 100644 --- a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py +++ b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py @@ -2,6 +2,7 @@ from collections.abc import Iterator, Mapping import contextlib +import gc import logging import os import random @@ -10,6 +11,7 @@ from typing import Any import jax +import jax.extend.backend as jax_backend import pathwaysutils from pathwaysutils.experimental.shared_pathways_service import gke_utils from pathwaysutils.experimental.shared_pathways_service import validators @@ -27,6 +29,7 @@ _JAX_PLATFORM_PROXY = "proxy" _JAX_BACKEND_TARGET_KEY = "jax_backend_target" _JAX_BACKEND_TARGET_HOSTNAME = "grpc://localhost" +_DEFAULT_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-0.8.0@sha256:5296fa0819d8cbdfbcf951ffca2072128255411557240624ff4011522a6a2abe" _logger = logging.getLogger(__name__) @@ -36,6 +39,7 @@ def _deploy_pathways_proxy_server( proxy_job_name: str, expected_instances: Mapping[Any, Any], gcs_scratch_location: str, + proxy_server_image: str, ) -> None: """Deploys the Pathways proxy pods to the GKE cluster. @@ -45,6 +49,7 @@ def _deploy_pathways_proxy_server( expected_instances: A dictionary mapping instance types to the number of instances. gcs_scratch_location: The Google Cloud Storage location to use. + proxy_server_image: The image to use for the proxy server. Raises: subprocess.CalledProcessError: If the kubectl command fails. @@ -70,6 +75,7 @@ def _deploy_pathways_proxy_server( PATHWAYS_HEAD_PORT=pathways_head_port, EXPECTED_INSTANCES=instances_str, GCS_SCRATCH_LOCATION=gcs_scratch_location, + PROXY_SERVER_IMAGE=proxy_server_image, ) _logger.info("Deploying Pathways proxy: %s", proxy_job_name) @@ -89,6 +95,8 @@ class _ISCPathways: pathways_service: The service name and port of the Pathways head pod. expected_tpu_instances: A dictionary mapping TPU machine types to the number of instances. + proxy_job_name: The name to use for the deployed proxy. + proxy_server_image: The image to use for the proxy server. """ def __init__( @@ -100,6 +108,7 @@ def __init__( pathways_service: str, expected_tpu_instances: Mapping[Any, Any], proxy_job_name: str | None, + proxy_server_image: str, ): """Initializes the TPU manager.""" self.cluster = cluster @@ -115,6 +124,7 @@ def __init__( self._proxy_job_name = proxy_job_name or f"isc-proxy-{user}-{suffix}" self._port_forward_process = None self._proxy_port = None + self.proxy_server_image = proxy_server_image def __repr__(self): return ( @@ -133,6 +143,7 @@ def __enter__(self): proxy_job_name=self._proxy_job_name, expected_instances=self.expected_tpu_instances, gcs_scratch_location=self.bucket, + proxy_server_image=self.proxy_server_image, ) # Print a link to Cloud Logging cloud_logging_link = gke_utils.get_log_link( @@ -172,7 +183,16 @@ def __exit__(self, exc_type, exc_value, traceback): def _cleanup(self): """Cleans up resources created by the ISCPathways context.""" + # 1. Clear JAX caches and run garbage collection. + _logger.info("Starting Pathways proxy cleanup.") + jax_backend.clear_backends() + jax.clear_caches() + gc.collect() + _logger.info("Cleared JAX caches and ran garbage collection.") + + # 2. Terminate the port forwarding process. if self._port_forward_process: + _logger.info("Terminating port forwarding process...") self._port_forward_process.terminate() try: self._port_forward_process.wait(timeout=10) @@ -183,8 +203,10 @@ def _cleanup(self): e, ) - _logger.info("Deleting Pathways proxy") + # 3. Delete the proxy GKE job. + _logger.info("Deleting Pathways proxy...") gke_utils.delete_gke_job(self._proxy_job_name) + _logger.info("Pathways proxy GKE job deletion complete.") @contextlib.contextmanager @@ -196,6 +218,7 @@ def connect( pathways_service: str, expected_tpu_instances: Mapping[str, int], proxy_job_name: str | None = None, + proxy_server_image: str | None = _DEFAULT_PROXY_IMAGE, ) -> Iterator["_ISCPathways"]: """Connects to a Pathways server if the cluster exists. If not, creates it. @@ -209,6 +232,8 @@ def connect( of instances. For example: {"tpuv6e:2x2": 2} proxy_job_name: The name to use for the deployed proxy. If not provided, a random name will be generated. + proxy_server_image: The proxy server image to use. If not provided, a + default will be used. Yields: The Pathways manager. @@ -216,6 +241,7 @@ def connect( _logger.info("Validating Pathways service and TPU instances...") validators.validate_pathways_service(pathways_service) validators.validate_tpu_instances(expected_tpu_instances) + validators.validate_proxy_server_image(proxy_server_image) _logger.info("Validation complete.") gke_utils.fetch_cluster_credentials( cluster_name=cluster, project_id=project, location=region @@ -229,5 +255,6 @@ def connect( pathways_service=pathways_service, expected_tpu_instances=expected_tpu_instances, proxy_job_name=proxy_job_name, + proxy_server_image=proxy_server_image, ) as t: yield t diff --git a/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py b/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py index e2092cd..38d4778 100644 --- a/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py +++ b/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py @@ -24,6 +24,17 @@ "tpu_type", "tpuv6e:2x2", "The TPU machine type and topology." ) flags.DEFINE_integer("tpu_count", 1, "The number of TPU slices.") +flags.DEFINE_string( + "proxy_job_name", + None, + "The name to use for the deployed proxy. If not provided, a random name" + " will be generated.", +) +flags.DEFINE_string( + "proxy_server_image", + None, + "The proxy server image to use. If not provided, a default will be used.", +) flags.mark_flags_as_required([ "cluster", @@ -37,6 +48,13 @@ def main(argv: Sequence[str]) -> None: if len(argv) > 1: raise app.UsageError("Too many command-line arguments.") + + kwargs = {} + if FLAGS.proxy_job_name: + kwargs["proxy_job_name"] = FLAGS.proxy_job_name + if FLAGS.proxy_server_image: + kwargs["proxy_server_image"] = FLAGS.proxy_server_image + with isc_pathways.connect( cluster=FLAGS.cluster, project=FLAGS.project, @@ -44,6 +62,7 @@ def main(argv: Sequence[str]) -> None: gcs_bucket=FLAGS.gcs_bucket, pathways_service=FLAGS.pathways_service, expected_tpu_instances={FLAGS.tpu_type: FLAGS.tpu_count}, + **kwargs, ): orig_matrix = jnp.zeros(5) result_matrix = orig_matrix + 1 diff --git a/pathwaysutils/experimental/shared_pathways_service/validators.py b/pathwaysutils/experimental/shared_pathways_service/validators.py index 1b7eae6..0276d5a 100644 --- a/pathwaysutils/experimental/shared_pathways_service/validators.py +++ b/pathwaysutils/experimental/shared_pathways_service/validators.py @@ -89,3 +89,19 @@ def validate_tpu_instances(expected_tpu_instances: Mapping[Any, Any]) -> None: inst = next(iter(expected_tpu_instances.keys())) _validate_tpu_supported(inst) + + +def validate_proxy_server_image(proxy_server_image: str) -> None: + """Validates the proxy server image format.""" + if not proxy_server_image or not proxy_server_image.strip(): + raise ValueError("Proxy server image cannot be empty.") + if "/" not in proxy_server_image: + raise ValueError( + f"Proxy server image '{proxy_server_image}' must contain '/', " + "separating the registry or namespace from the final image name." + ) + if ":" not in proxy_server_image and "@" not in proxy_server_image: + raise ValueError( + f"Proxy server image '{proxy_server_image}' must contain a tag with ':'" + " or a digest with '@'." + ) diff --git a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml index e91c795..a120c23 100644 --- a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml +++ b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml @@ -14,7 +14,7 @@ spec: automountServiceAccountToken: false containers: - name: pathways-proxy - image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-0.8.0@sha256:5296fa0819d8cbdfbcf951ffca2072128255411557240624ff4011522a6a2abe + image: ${PROXY_SERVER_IMAGE} imagePullPolicy: Always args: - --server_port=${PROXY_SERVER_PORT}