Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections.abc import Iterator, Mapping
import contextlib
import gc
import logging
import os
import random
Expand All @@ -10,6 +11,7 @@
from typing import Any

import jax
import jax.extend.backend as jax_backend
import pathwaysutils
from pathwaysutils.experimental.shared_pathways_service import gke_utils
from pathwaysutils.experimental.shared_pathways_service import validators
Expand All @@ -27,6 +29,7 @@
_JAX_PLATFORM_PROXY = "proxy"
_JAX_BACKEND_TARGET_KEY = "jax_backend_target"
_JAX_BACKEND_TARGET_HOSTNAME = "grpc://localhost"
_DEFAULT_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-0.8.0@sha256:5296fa0819d8cbdfbcf951ffca2072128255411557240624ff4011522a6a2abe"

_logger = logging.getLogger(__name__)

Expand All @@ -36,6 +39,7 @@ def _deploy_pathways_proxy_server(
proxy_job_name: str,
expected_instances: Mapping[Any, Any],
gcs_scratch_location: str,
proxy_server_image: str,
) -> None:
"""Deploys the Pathways proxy pods to the GKE cluster.

Expand All @@ -45,6 +49,7 @@ def _deploy_pathways_proxy_server(
expected_instances: A dictionary mapping instance types to the number of
instances.
gcs_scratch_location: The Google Cloud Storage location to use.
proxy_server_image: The image to use for the proxy server.

Raises:
subprocess.CalledProcessError: If the kubectl command fails.
Expand All @@ -70,6 +75,7 @@ def _deploy_pathways_proxy_server(
PATHWAYS_HEAD_PORT=pathways_head_port,
EXPECTED_INSTANCES=instances_str,
GCS_SCRATCH_LOCATION=gcs_scratch_location,
PROXY_SERVER_IMAGE=proxy_server_image,
)

_logger.info("Deploying Pathways proxy: %s", proxy_job_name)
Expand All @@ -89,6 +95,8 @@ class _ISCPathways:
pathways_service: The service name and port of the Pathways head pod.
expected_tpu_instances: A dictionary mapping TPU machine types to the number
of instances.
proxy_job_name: The name to use for the deployed proxy.
proxy_server_image: The image to use for the proxy server.
"""

def __init__(
Expand All @@ -100,6 +108,7 @@ def __init__(
pathways_service: str,
expected_tpu_instances: Mapping[Any, Any],
proxy_job_name: str | None,
proxy_server_image: str,
):
"""Initializes the TPU manager."""
self.cluster = cluster
Expand All @@ -115,6 +124,7 @@ def __init__(
self._proxy_job_name = proxy_job_name or f"isc-proxy-{user}-{suffix}"
self._port_forward_process = None
self._proxy_port = None
self.proxy_server_image = proxy_server_image

def __repr__(self):
return (
Expand All @@ -133,6 +143,7 @@ def __enter__(self):
proxy_job_name=self._proxy_job_name,
expected_instances=self.expected_tpu_instances,
gcs_scratch_location=self.bucket,
proxy_server_image=self.proxy_server_image,
)
# Print a link to Cloud Logging
cloud_logging_link = gke_utils.get_log_link(
Expand Down Expand Up @@ -172,7 +183,16 @@ def __exit__(self, exc_type, exc_value, traceback):

def _cleanup(self):
"""Cleans up resources created by the ISCPathways context."""
# 1. Clear JAX caches and run garbage collection.
_logger.info("Starting Pathways proxy cleanup.")
jax_backend.clear_backends()
jax.clear_caches()
gc.collect()
_logger.info("Cleared JAX caches and ran garbage collection.")

# 2. Terminate the port forwarding process.
if self._port_forward_process:
_logger.info("Terminating port forwarding process...")
self._port_forward_process.terminate()
try:
self._port_forward_process.wait(timeout=10)
Expand All @@ -183,8 +203,10 @@ def _cleanup(self):
e,
)

_logger.info("Deleting Pathways proxy")
# 3. Delete the proxy GKE job.
_logger.info("Deleting Pathways proxy...")
gke_utils.delete_gke_job(self._proxy_job_name)
_logger.info("Pathways proxy GKE job deletion complete.")


@contextlib.contextmanager
Expand All @@ -196,6 +218,7 @@ def connect(
pathways_service: str,
expected_tpu_instances: Mapping[str, int],
proxy_job_name: str | None = None,
proxy_server_image: str | None = _DEFAULT_PROXY_IMAGE,
) -> Iterator["_ISCPathways"]:
"""Connects to a Pathways server if the cluster exists. If not, creates it.

Expand All @@ -209,13 +232,16 @@ def connect(
of instances. For example: {"tpuv6e:2x2": 2}
proxy_job_name: The name to use for the deployed proxy. If not provided, a
random name will be generated.
proxy_server_image: The proxy server image to use. If not provided, a
default will be used.

Yields:
The Pathways manager.
"""
_logger.info("Validating Pathways service and TPU instances...")
validators.validate_pathways_service(pathways_service)
validators.validate_tpu_instances(expected_tpu_instances)
validators.validate_proxy_server_image(proxy_server_image)
_logger.info("Validation complete.")
gke_utils.fetch_cluster_credentials(
cluster_name=cluster, project_id=project, location=region
Expand All @@ -229,5 +255,6 @@ def connect(
pathways_service=pathways_service,
expected_tpu_instances=expected_tpu_instances,
proxy_job_name=proxy_job_name,
proxy_server_image=proxy_server_image,
) as t:
yield t
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@
"tpu_type", "tpuv6e:2x2", "The TPU machine type and topology."
)
flags.DEFINE_integer("tpu_count", 1, "The number of TPU slices.")
flags.DEFINE_string(
"proxy_job_name",
None,
"The name to use for the deployed proxy. If not provided, a random name"
" will be generated.",
)
flags.DEFINE_string(
"proxy_server_image",
None,
"The proxy server image to use. If not provided, a default will be used.",
)

flags.mark_flags_as_required([
"cluster",
Expand All @@ -37,13 +48,21 @@
def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")

kwargs = {}
if FLAGS.proxy_job_name:
kwargs["proxy_job_name"] = FLAGS.proxy_job_name
if FLAGS.proxy_server_image:
kwargs["proxy_server_image"] = FLAGS.proxy_server_image

with isc_pathways.connect(
cluster=FLAGS.cluster,
project=FLAGS.project,
region=FLAGS.region,
gcs_bucket=FLAGS.gcs_bucket,
pathways_service=FLAGS.pathways_service,
expected_tpu_instances={FLAGS.tpu_type: FLAGS.tpu_count},
**kwargs,
):
orig_matrix = jnp.zeros(5)
result_matrix = orig_matrix + 1
Expand Down
16 changes: 16 additions & 0 deletions pathwaysutils/experimental/shared_pathways_service/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,19 @@ def validate_tpu_instances(expected_tpu_instances: Mapping[Any, Any]) -> None:

inst = next(iter(expected_tpu_instances.keys()))
_validate_tpu_supported(inst)


def validate_proxy_server_image(proxy_server_image: str) -> None:
"""Validates the proxy server image format."""
if not proxy_server_image or not proxy_server_image.strip():
raise ValueError("Proxy server image cannot be empty.")
if "/" not in proxy_server_image:
raise ValueError(
f"Proxy server image '{proxy_server_image}' must contain '/', "
"separating the registry or namespace from the final image name."
)
if ":" not in proxy_server_image and "@" not in proxy_server_image:
raise ValueError(
f"Proxy server image '{proxy_server_image}' must contain a tag with ':'"
" or a digest with '@'."
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ spec:
automountServiceAccountToken: false
containers:
- name: pathways-proxy
image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-0.8.0@sha256:5296fa0819d8cbdfbcf951ffca2072128255411557240624ff4011522a6a2abe
image: ${PROXY_SERVER_IMAGE}
imagePullPolicy: Always
args:
- --server_port=${PROXY_SERVER_PORT}
Expand Down
Loading