Fix TPU7x chip counting to account for chiplet architecture #1266
For TPU7x devices, each physical chip contains 2 chiplets that are exposed to the host as separate devices. This was causing get_num_chips() to report double the actual chip count.

Example: tpu7x-8 has 4 physical chips (8 chiplets total)
- Before: reported num_chips=8 (incorrect)
- After: reports num_chips=4 (correct)

Implementation:
- Modified get_num_chips() to detect tpu7x devices and divide the device count by 2 using integer division
- Added test_get_num_chips_tpu7x to verify correct chip counting

All non-tpu7x devices are unaffected by this change.

Signed-off-by: burbajr <joey.burba@gmail.com>
    # For tpu7x, each chip has 2 chiplets exposed as separate devices
    tpu_type = get_tpu_type()
    if tpu_type and "tpu7x" in tpu_type.lower():
        return num_devices // 2
Not really a fan of hard-coding the value "2" or having "tpu7x"-specific logic.
Is there any logic within tpu_info that programmatically fetches num_devices?
Or wouldn't a simple len(jax.devices()) return the correct number of cores, regardless of whether a chip has 1 chiplet or 2 chiplets?
Hi, and thank you for taking a closer look at this.
The original logic (below) for determining the number of cores per TPU family remains correct for v7, which has 2 cores per chip, so no changes are needed there.
    def get_num_cores_per_chip() -> int:
        tpu_type = get_tpu_type()
        if tpu_type.startswith(("v5litepod", "v6e")):
            return 1
        return 2
The issue is only with chip count inference. Prior to v7, each physical chip (regardless of core counts) exposed a single logical device, so using jax.devices() or the OS device list was a reliable way to infer the number of chips, since there was a 1:1 mapping between physical chips and logical devices. The original logic in tpu_info simply counts the number of host devices to report the number of chips.
With v7, each physical chip now exposes two logical devices, which means len(jax.devices()) and the OS device list report twice as many devices as there are actual chips, leading to the current reporting bug.
I was not able to identify a reliable programmatic way to retrieve the physical TPU topology on TPU/GKE nodes, so dividing the device count by two is the most accurate approach I have found for correcting the reported chip count.
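A minimal sketch of the correction described in this comment, for illustration only. The helper name count_chips and the explicit device_paths argument are hypothetical (the project reads the device list from the host); the "tpu7x" substring check and the divide-by-two rule are taken from the diff under review.

```python
from typing import List, Optional


def count_chips(device_paths: List[str], tpu_type: Optional[str]) -> int:
    """Infer the physical chip count from host-visible device paths."""
    num_devices = len(device_paths)
    # Per the discussion: a tpu7x chip exposes two logical devices (chiplets).
    if tpu_type and "tpu7x" in tpu_type.lower():
        return num_devices // 2
    # Earlier families map one device to one chip, so the count is used as-is.
    return num_devices


paths = [f"/dev/accel{i}" for i in range(8)]
print(count_chips(paths, "tpu7x-8"))  # 4
print(count_chips(paths, "v6e-8"))    # 8
```

The same eight device paths yield 4 chips on tpu7x-8 and 8 chips on v6e-8, which is the discrepancy the PR is trying to correct.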
Hi, curious if anything is still outstanding on this PR or if you have any questions. Thank you!
I feel this change is unnecessarily complicated. I think just changing our wording from "num chips" to "num cores" would be sufficient.
Hi, thanks for taking another look at this. Let me try to explain the motivation a bit more clearly.
The existing / original logic already hard codes the number of cores per chip based on the TPU family prefix, and that part is correct and unchanged. v5litepod and v6e have 1 core per chip, and other families have 2 cores per chip.
    def get_num_cores_per_chip() -> int:
        tpu_type = get_tpu_type()
        if tpu_type.startswith(("v5litepod", "v6e")):
            return 1
        return 2
Where things get tricky is inferring the total number of chips or total number of cores from the number of JAX or OS visible devices. That mapping is no longer consistent across TPU families.
For example, using the same 2x2x1 topology:
- v5p: 4 chips, 2 cores per chip, 8 cores total. This exposes 4 JAX or OS devices.
- v7: 4 chips, 2 cores per chip, 8 cores total. This exposes 8 JAX or OS devices.
So simply counting len(jax.devices()) gives different results depending on the TPU generation, even though the physical topology and total core count are the same. Because of that, changing the field to represent cores without additional logic would still be inaccurate in some cases.
I do not have a strong opinion on the final approach. Whether we preserve chip semantics, rename fields, or restructure how this is reported, I mainly wanted to point out that the current logic cannot reliably calculate either total chips or total cores if it relies only on device count.
Happy to align on whatever direction makes the most sense here, and thanks for taking a look.
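The 2x2x1 comparison above can be worked through in a few lines. This is a sketch under stated assumptions: FamilyLayout and chips_and_cores are hypothetical names, not part of tpu_info, and the per-family numbers are the ones quoted in this comment.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class FamilyLayout:
    devices_per_chip: int  # logical devices the host sees per physical chip
    cores_per_chip: int    # cores per physical chip


def chips_and_cores(num_devices: int, layout: FamilyLayout) -> Tuple[int, int]:
    """Derive (chips, total cores) from the visible device count."""
    chips = num_devices // layout.devices_per_chip
    return chips, chips * layout.cores_per_chip


# Numbers from the 2x2x1 example in the comment above.
V5P = FamilyLayout(devices_per_chip=1, cores_per_chip=2)
V7 = FamilyLayout(devices_per_chip=2, cores_per_chip=2)

print(chips_and_cores(4, V5P))  # (4, 8)
print(chips_and_cores(8, V7))   # (4, 8)
```

Both families resolve to 4 chips and 8 cores, but only once the devices-per-chip ratio is known; the raw device counts (4 versus 8) match neither quantity consistently.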
Updated implementation based on feedback to use a simpler approach:
        "/dev/accel0", "/dev/accel1", "/dev/accel2", "/dev/accel3",
        "/dev/accel4", "/dev/accel5", "/dev/accel6", "/dev/accel7"
    ])
    def test_get_num_chips_tpu7x(mock_glob, mock_get_tpu_type):
Instead of creating a separate unit test for tpu7x, can you use the existing unit test and add a parameter for tpu7x?
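One way the suggestion above might look is a single table-driven test that covers tpu7x and non-tpu7x types together. This is only a sketch: num_chips_for is a hypothetical stand-in for the project's get_num_chips(), and the cases are the ones listed in the PR description. It uses a plain loop with unittest.mock so it runs on the standard library alone; under pytest the same CASES table could be passed to pytest.mark.parametrize.

```python
import glob
from unittest import mock


def num_chips_for(tpu_type):
    """Hypothetical stand-in for get_num_chips(); not the project's code."""
    num_devices = len(glob.glob("/dev/accel*"))
    if tpu_type and "tpu7x" in tpu_type.lower():
        return num_devices // 2
    return num_devices


# (tpu_type, number of fake /dev/accel* devices, expected chip count)
CASES = [
    ("tpu7x-8", 8, 4),
    ("tpu7x-4", 4, 2),
    ("v6e-8", 8, 8),
]


def check_case(tpu_type, num_devices, expected):
    fake_paths = [f"/dev/accel{i}" for i in range(num_devices)]
    # Replace glob.glob so the test does not depend on real device nodes.
    with mock.patch("glob.glob", return_value=fake_paths):
        assert num_chips_for(tpu_type) == expected


def run_num_chips_table():
    for case in CASES:
        check_case(*case)


run_num_chips_table()
```

Adding a new TPU type then means adding one row to CASES rather than writing another test function.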
Description
For TPU7x devices, each physical chip contains 2 chiplets that are exposed to the host as separate devices. The previous implementation counted these devices directly, resulting in 2x the actual chip count being reported in logs.
Problem Being Solved
When running on tpu7x-8 hardware (which has 4 physical chips with 2 chiplets each), the system would log:
TPU info: ... | num_chips=8 | num_cores_per_chip=2
This was misleading because it reported 8 chips when there are only 4 physical chips. Each chip has 2 chiplets exposed as separate devices to the host.
Solution
This PR adds chiplet-aware logic to the chip counting mechanism:
- get_num_chiplets_per_chip() returns 2 for tpu7x devices and 1 for all other TPU types
- get_num_chips() divides the device count by chiplets per chip
- num_chiplets_per_chip is logged for tpu7x devices only (to avoid confusing non-tpu7x users)
Why This is a Good Solution
- Consistent with get_num_cores_per_chip()
- Non-tpu7x devices have chiplets_per_chip=1, making the division a no-op
Implementation Details
The fix uses integer division (//), which is safe here.
Logging is conditional - chiplet info only appears for tpu7x:
tpu7x-8 (after fix):
TPU info: ... | num_chips=4 | num_chiplets_per_chip=2 | num_cores_per_chip=2
v6e-8 (unchanged, no chiplet info to avoid confusion):
TPU info: ... | num_chips=8 | num_cores_per_chip=1
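The helper and the conditional logging described above could be sketched as follows. The signatures here are assumptions: the PR's get_num_chiplets_per_chip() appears to take no arguments, whereas this version takes the TPU type explicitly so it runs standalone, and format_chip_fields is a hypothetical name that builds only the fields shown in the log lines above.

```python
from typing import Optional


def get_num_chiplets_per_chip(tpu_type: Optional[str]) -> int:
    """Return 2 for tpu7x types and 1 otherwise (including None or '')."""
    if tpu_type and "tpu7x" in tpu_type.lower():
        return 2
    return 1


def format_chip_fields(tpu_type: Optional[str], num_devices: int,
                       num_cores_per_chip: int) -> str:
    """Build the chip-related log fields; chiplet info only for tpu7x."""
    chiplets = get_num_chiplets_per_chip(tpu_type)
    fields = [f"num_chips={num_devices // chiplets}"]
    if chiplets > 1:
        fields.append(f"num_chiplets_per_chip={chiplets}")
    fields.append(f"num_cores_per_chip={num_cores_per_chip}")
    return " | ".join(fields)


print(format_chip_fields("tpu7x-8", 8, 2))
# num_chips=4 | num_chiplets_per_chip=2 | num_cores_per_chip=2
print(format_chip_fields("v6e-8", 8, 1))
# num_chips=8 | num_cores_per_chip=1
```

For non-tpu7x types the divisor is 1, so the chip count equals the device count and the chiplet field is omitted.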
Test Coverage Added
- test_get_num_chiplets_per_chip: Tests all TPU types including tpu7x-8, tpu7x-4, v6e-8, v5litepod, and edge cases (None, empty string)
- test_get_num_chips_tpu7x_from_accel: Verifies tpu7x-8 with 8 /dev/accel* devices returns 4 chips
- test_get_num_chips_tpu7x_4_from_accel: Verifies tpu7x-4 with 4 /dev/accel* devices returns 2 chips
- test_get_num_chips_tpu7x_from_vfio: Verifies tpu7x-8 with /dev/vfio path returns 4 chips
- test_get_num_chips_non_tpu7x_unchanged: Verifies v6e-8 still returns 8 chips (backward compatibility)
How to Test
pytest tests/test_tpu_info.py -v
Run specific new tests
pytest tests/test_tpu_info.py::test_get_num_chiplets_per_chip -v
pytest tests/test_tpu_info.py::test_get_num_chips_tpu7x_from_accel -v
pytest tests/test_tpu_info.py::test_get_num_chips_non_tpu7x_unchanged -v
All tests pass with the changes.