2727_JAX_PLATFORM_PROXY = "proxy"
2828_JAX_BACKEND_TARGET_KEY = "jax_backend_target"
2929_JAX_BACKEND_TARGET_HOSTNAME = "grpc://localhost"
30+ _DEFAULT_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-0.8.0@sha256:5296fa0819d8cbdfbcf951ffca2072128255411557240624ff4011522a6a2abe"
3031
3132_logger = logging .getLogger (__name__ )
3233
@@ -36,6 +37,7 @@ def _deploy_pathways_proxy_server(
3637 proxy_job_name : str ,
3738 expected_instances : Mapping [Any , Any ],
3839 gcs_scratch_location : str ,
40+ proxy_server_image : str ,
3941) -> None :
4042 """Deploys the Pathways proxy pods to the GKE cluster.
4143
@@ -45,6 +47,7 @@ def _deploy_pathways_proxy_server(
4547 expected_instances: A dictionary mapping instance types to the number of
4648 instances.
4749 gcs_scratch_location: The Google Cloud Storage location to use.
50+ proxy_server_image: The image to use for the proxy server.
4851
4952 Raises:
5053 subprocess.CalledProcessError: If the kubectl command fails.
@@ -70,6 +73,7 @@ def _deploy_pathways_proxy_server(
7073 PATHWAYS_HEAD_PORT = pathways_head_port ,
7174 EXPECTED_INSTANCES = instances_str ,
7275 GCS_SCRATCH_LOCATION = gcs_scratch_location ,
76+ PROXY_SERVER_IMAGE = proxy_server_image ,
7377 )
7478
7579 _logger .info ("Deploying Pathways proxy: %s" , proxy_job_name )
@@ -89,6 +93,8 @@ class _ISCPathways:
8993 pathways_service: The service name and port of the Pathways head pod.
9094 expected_tpu_instances: A dictionary mapping TPU machine types to the number
9195 of instances.
96+ proxy_job_name: The name to use for the deployed proxy.
97+ proxy_server_image: The image to use for the proxy server.
9298 """
9399
94100 def __init__ (
@@ -100,6 +106,7 @@ def __init__(
100106 pathways_service : str ,
101107 expected_tpu_instances : Mapping [Any , Any ],
102108 proxy_job_name : str | None ,
109+ proxy_server_image : str ,
103110 ):
104111 """Initializes the TPU manager."""
105112 self .cluster = cluster
@@ -115,6 +122,7 @@ def __init__(
115122 self ._proxy_job_name = proxy_job_name or f"isc-proxy-{ user } -{ suffix } "
116123 self ._port_forward_process = None
117124 self ._proxy_port = None
125+ self .proxy_server_image = proxy_server_image
118126
119127 def __repr__ (self ):
120128 return (
@@ -133,6 +141,7 @@ def __enter__(self):
133141 proxy_job_name = self ._proxy_job_name ,
134142 expected_instances = self .expected_tpu_instances ,
135143 gcs_scratch_location = self .bucket ,
144+ proxy_server_image = self .proxy_server_image ,
136145 )
137146 # Print a link to Cloud Logging
138147 cloud_logging_link = gke_utils .get_log_link (
@@ -196,6 +205,7 @@ def connect(
196205 pathways_service : str ,
197206 expected_tpu_instances : Mapping [str , int ],
198207 proxy_job_name : str | None = None ,
208+ proxy_server_image : str | None = _DEFAULT_PROXY_IMAGE ,
199209) -> Iterator ["_ISCPathways" ]:
200210 """Connects to a Pathways server if the cluster exists. If not, creates it.
201211
@@ -209,13 +219,16 @@ def connect(
209219 of instances. For example: {"tpuv6e:2x2": 2}
210220 proxy_job_name: The name to use for the deployed proxy. If not provided, a
211221 random name will be generated.
222+ proxy_server_image: The proxy server image to use. If not provided, a
223+ default will be used.
212224
213225 Yields:
214226 The Pathways manager.
215227 """
216228 _logger .info ("Validating Pathways service and TPU instances..." )
217229 validators .validate_pathways_service (pathways_service )
218230 validators .validate_tpu_instances (expected_tpu_instances )
231+ validators .validate_proxy_server_image (proxy_server_image )
219232 _logger .info ("Validation complete." )
220233 gke_utils .fetch_cluster_credentials (
221234 cluster_name = cluster , project_id = project , location = region
@@ -229,5 +242,6 @@ def connect(
229242 pathways_service = pathways_service ,
230243 expected_tpu_instances = expected_tpu_instances ,
231244 proxy_job_name = proxy_job_name ,
245+ proxy_server_image = proxy_server_image ,
232246 ) as t :
233247 yield t
0 commit comments