diff --git a/tpu_inference/executors/ray_distributed_executor.py b/tpu_inference/executors/ray_distributed_executor.py index 05d96e0b2..7b961b036 100644 --- a/tpu_inference/executors/ray_distributed_executor.py +++ b/tpu_inference/executors/ray_distributed_executor.py @@ -369,7 +369,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): for pp_rank in range(self.parallel_config.pipeline_parallel_size): self.pp_tp_workers.append([]) num_tp_workers = int( - self.parallel_config.tensor_parallel_size // + self.vllm_config.sharding_config.total_devices // num_tpu_per_worker) for tp_rank in range(num_tp_workers): # PP=2, TP=4, num_tpu_per_worker=2