1 parent f72797f · commit a2e9684
tpu_inference/executors/ray_distributed_executor.py
@@ -366,7 +366,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
         for pp_rank in range(self.parallel_config.pipeline_parallel_size):
             self.pp_tp_workers.append([])
             num_tp_workers = int(
-                self.parallel_config.tensor_parallel_size //
+                self.vllm_config.sharding_config.total_devices //
                 num_tpu_per_worker)
             for tp_rank in range(num_tp_workers):
                 # PP=2, TP=4, num_tpu_per_worker=2
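
The hunk changes only the numerator used to derive num_tp_workers: the worker-grouping loop now divides the sharding config's total device count, rather than tensor_parallel_size, by the number of TPU chips per worker. Below is a minimal standalone sketch of that arithmetic; the concrete values (PP=2, total_devices=4, num_tpu_per_worker=2) are assumptions taken from the inline comment, and the flat variables stand in for the executor's config objects purely for illustration.

# Sketch of the worker-grouping arithmetic in the hunk above.
# Values are assumed for illustration; in the executor they come from
# self.parallel_config, self.vllm_config.sharding_config, and the Ray setup.
pipeline_parallel_size = 2   # PP=2
total_devices = 4            # replaces tensor_parallel_size in the new code
num_tpu_per_worker = 2

pp_tp_workers = []
for pp_rank in range(pipeline_parallel_size):
    pp_tp_workers.append([])
    # With total_devices=4 and num_tpu_per_worker=2, each PP stage
    # is split into 4 // 2 = 2 TP worker slots.
    num_tp_workers = int(total_devices // num_tpu_per_worker)
    for tp_rank in range(num_tp_workers):
        pp_tp_workers[pp_rank].append((pp_rank, tp_rank))

print(pp_tp_workers)  # [[(0, 0), (0, 1)], [(1, 0), (1, 1)]]

With these assumed values the result is 2 PP stages of 2 TP worker slots each, matching the PP=2, TP=4, num_tpu_per_worker=2 example in the diff's comment.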