Improve pathways checkpoint load times #1345
Conversation
* Use shared memory between the JAX client and the Pathways proxy for data-heavy transfers, e.g. `device_put`s.
* Increase the `ThreadPoolExecutor` thread count from 32 (the Python default) to 192.
* Remove the memory limit from the Pathways head main container.

Callers should use a `concurrent_restore_gb` as large as possible without OOMing; otherwise the GCS read and the `device_put` won't happen in parallel. The default of 32GB is too low to achieve optimal performance with Pathways.
```diff
-# This image version extends GRPC timeout for long context models, based on jax-0.5.3-patch060625
+# This image extends GRPC timeout for long context models.
-_PATHWAYS_IMAGE_TAG = "disable_settings_20250701"
+_PATHWAYS_IMAGE_TAG = "shm_proxy"
```
Could you double-check with Shauray that this binary includes the patch extending the GRPC timeout? Or does it not need it anymore?
```diff
+# The flag below is needed for better H2D performance.
+# Rule of thumb: 3x the shard size. So 128GB to be safe.
+# Decrease if you start running out of host memory on TPU VMs.
+"--tpu_premapped_buffer_size=137438953472",
```
Let's use 1/4 of the machine type's host memory, rounded up to the nearest power of 2:
https://github.com/apple/axlearn/blob/main/axlearn/cloud/gcp/system_characteristics.py#L494-L499
```diff
 self._loop_thread.start()
-self._single_thread_pool = ThreadPoolExecutor(1)
+self._single_thread_pool = ThreadPoolExecutor(max_workers=1)
+self._multi_thread_pool = ThreadPoolExecutor(max_workers=192)
```
Can we make this a config flag? It depends on how many CPUs we allocate to the head pod: https://github.com/apple/axlearn/blob/main/axlearn/cloud/gcp/pathways_utils.py#L317
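Making the pool size configurable could look roughly like this. This is only a sketch under assumptions: axlearn configs are built with its own config system rather than dataclasses, and `max_transfer_workers` is an invented field name:

```python
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass


@dataclass
class ProxyConfig:
    # Hypothetical knob; would default to 192 but be tuned to the
    # CPU count allocated to the pathways head pod.
    max_transfer_workers: int = 192


def make_pools(cfg: ProxyConfig):
    """Build the serial pool and the configurable transfer pool."""
    single = ThreadPoolExecutor(max_workers=1)
    multi = ThreadPoolExecutor(max_workers=cfg.max_transfer_workers)
    return single, multi
```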
```python
mem_req = f"{self.config.pathways_head_mem}Gi"
resources = {
    "requests": {"cpu": cpu_req, "memory": mem_req},
    "limits": {"cpu": cpu_req, "memory": mem_req},
```
For my education, what's the effect of having "request" and not "limit"?
This pull request has been automatically marked as stale because it has been inactive for 60 days. It will be closed in 7 days if no further activity occurs. If you would like to continue working on this, please remove the stale label.
This pull request was closed because it has been inactive for more than 7 days since being marked as stale. Please feel free to reopen it if you would like to continue.