diff --git a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py index 9544e00..cf82764 100644 --- a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py +++ b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py @@ -27,6 +27,7 @@ _JAX_PLATFORM_PROXY = "proxy" _JAX_BACKEND_TARGET_KEY = "jax_backend_target" _JAX_BACKEND_TARGET_HOSTNAME = "grpc://localhost" +_DEFAULT_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-0.8.0@sha256:5296fa0819d8cbdfbcf951ffca2072128255411557240624ff4011522a6a2abe" _logger = logging.getLogger(__name__) @@ -36,6 +37,7 @@ def _deploy_pathways_proxy_server( proxy_job_name: str, expected_instances: Mapping[Any, Any], gcs_scratch_location: str, + proxy_server_image: str, ) -> None: """Deploys the Pathways proxy pods to the GKE cluster. @@ -45,6 +47,7 @@ def _deploy_pathways_proxy_server( expected_instances: A dictionary mapping instance types to the number of instances. gcs_scratch_location: The Google Cloud Storage location to use. + proxy_server_image: The image to use for the proxy server. Raises: subprocess.CalledProcessError: If the kubectl command fails. @@ -70,6 +73,7 @@ def _deploy_pathways_proxy_server( PATHWAYS_HEAD_PORT=pathways_head_port, EXPECTED_INSTANCES=instances_str, GCS_SCRATCH_LOCATION=gcs_scratch_location, + PROXY_SERVER_IMAGE=proxy_server_image, ) _logger.info("Deploying Pathways proxy: %s", proxy_job_name) @@ -89,6 +93,8 @@ class _ISCPathways: pathways_service: The service name and port of the Pathways head pod. expected_tpu_instances: A dictionary mapping TPU machine types to the number of instances. + proxy_job_name: The name to use for the deployed proxy. + proxy_server_image: The image to use for the proxy server. """ def __init__( @@ -100,6 +106,7 @@ def __init__( pathways_service: str, expected_tpu_instances: Mapping[Any, Any], proxy_job_name: str | None, + proxy_server_image: str, ): """Initializes the TPU manager.""" self.cluster = cluster @@ -115,6 +122,7 @@ def __init__( self._proxy_job_name = proxy_job_name or f"isc-proxy-{user}-{suffix}" self._port_forward_process = None self._proxy_port = None + self.proxy_server_image = proxy_server_image def __repr__(self): return ( @@ -133,6 +141,7 @@ def __enter__(self): proxy_job_name=self._proxy_job_name, expected_instances=self.expected_tpu_instances, gcs_scratch_location=self.bucket, + proxy_server_image=self.proxy_server_image, ) # Print a link to Cloud Logging cloud_logging_link = gke_utils.get_log_link( @@ -196,6 +205,7 @@ def connect( pathways_service: str, expected_tpu_instances: Mapping[str, int], proxy_job_name: str | None = None, + proxy_server_image: str | None = _DEFAULT_PROXY_IMAGE, ) -> Iterator["_ISCPathways"]: """Connects to a Pathways server if the cluster exists. If not, creates it. @@ -209,6 +219,8 @@ def connect( of instances. For example: {"tpuv6e:2x2": 2} proxy_job_name: The name to use for the deployed proxy. If not provided, a random name will be generated. + proxy_server_image: The proxy server image to use. If not provided, a + default will be used. Yields: The Pathways manager. @@ -216,6 +228,7 @@ def connect( _logger.info("Validating Pathways service and TPU instances...") validators.validate_pathways_service(pathways_service) validators.validate_tpu_instances(expected_tpu_instances) + validators.validate_proxy_server_image(proxy_server_image) _logger.info("Validation complete.") gke_utils.fetch_cluster_credentials( cluster_name=cluster, project_id=project, location=region @@ -229,5 +242,6 @@ def connect( pathways_service=pathways_service, expected_tpu_instances=expected_tpu_instances, proxy_job_name=proxy_job_name, + proxy_server_image=proxy_server_image, ) as t: yield t diff --git a/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py b/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py index e2092cd..38d4778 100644 --- a/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py +++ b/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py @@ -24,6 +24,17 @@ "tpu_type", "tpuv6e:2x2", "The TPU machine type and topology." ) flags.DEFINE_integer("tpu_count", 1, "The number of TPU slices.") +flags.DEFINE_string( + "proxy_job_name", + None, + "The name to use for the deployed proxy. If not provided, a random name" + " will be generated.", +) +flags.DEFINE_string( + "proxy_server_image", + None, + "The proxy server image to use. If not provided, a default will be used.", +) flags.mark_flags_as_required([ "cluster", @@ -37,6 +48,13 @@ def main(argv: Sequence[str]) -> None: if len(argv) > 1: raise app.UsageError("Too many command-line arguments.") + + kwargs = {} + if FLAGS.proxy_job_name: + kwargs["proxy_job_name"] = FLAGS.proxy_job_name + if FLAGS.proxy_server_image: + kwargs["proxy_server_image"] = FLAGS.proxy_server_image + with isc_pathways.connect( cluster=FLAGS.cluster, project=FLAGS.project, @@ -44,6 +62,7 @@ def main(argv: Sequence[str]) -> None: gcs_bucket=FLAGS.gcs_bucket, pathways_service=FLAGS.pathways_service, expected_tpu_instances={FLAGS.tpu_type: FLAGS.tpu_count}, + **kwargs, ): orig_matrix = jnp.zeros(5) result_matrix = orig_matrix + 1 diff --git a/pathwaysutils/experimental/shared_pathways_service/validators.py b/pathwaysutils/experimental/shared_pathways_service/validators.py index 1b7eae6..0276d5a 100644 --- a/pathwaysutils/experimental/shared_pathways_service/validators.py +++ b/pathwaysutils/experimental/shared_pathways_service/validators.py @@ -89,3 +89,19 @@ def validate_tpu_instances(expected_tpu_instances: Mapping[Any, Any]) -> None: inst = next(iter(expected_tpu_instances.keys())) _validate_tpu_supported(inst) + + +def validate_proxy_server_image(proxy_server_image: str) -> None: + """Validates the proxy server image format.""" + if not proxy_server_image or not proxy_server_image.strip(): + raise ValueError("Proxy server image cannot be empty.") + if "/" not in proxy_server_image: + raise ValueError( + f"Proxy server image '{proxy_server_image}' must contain '/', " + "separating the registry or namespace from the final image name." + ) + if ":" not in proxy_server_image and "@" not in proxy_server_image: + raise ValueError( + f"Proxy server image '{proxy_server_image}' must contain a tag with ':'" + " or a digest with '@'." + ) diff --git a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml index e91c795..a120c23 100644 --- a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml +++ b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml @@ -14,7 +14,7 @@ spec: automountServiceAccountToken: false containers: - name: pathways-proxy - image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-0.8.0@sha256:5296fa0819d8cbdfbcf951ffca2072128255411557240624ff4011522a6a2abe + image: ${PROXY_SERVER_IMAGE} imagePullPolicy: Always args: - --server_port=${PROXY_SERVER_PORT}