Skip to content

Conversation

@burbajr
Copy link

@burbajr burbajr commented Dec 8, 2025

Description

For TPU7x devices, each physical chip contains 2 chiplets that are exposed to the host as separate devices. The previous implementation counted these devices directly, resulting in 2x the actual chip count being reported in logs.

Problem Being Solved
When running on tpu7x-8 hardware (which has 4 physical chips with 2 chiplets each), the system would log:
TPU info: ... | num_chips=8 | num_cores_per_chip=2

This was misleading because it reported 8 chips when there are only 4 physical chips. Each chip has 2 chiplets exposed as separate devices to the host.

Solution
This PR adds chiplet-aware logic to the chip counting mechanism:

  1. New helper function get_num_chiplets_per_chip() that returns 2 for tpu7x devices and 1 for all other TPU types
  2. Modified get_num_chips() to divide the device count by chiplets per chip
  3. Enhanced logging to show num_chiplets_per_chip for tpu7x devices only (to avoid confusing non-tpu7x users)

Why This is a Good Solution

  • Follows existing patterns: Uses the same pattern as get_num_cores_per_chip() for consistency
  • Backward compatible: All non-tpu7x devices get chiplets_per_chip=1, making the division a no-op
  • Minimal code changes: Only touches the necessary functions
  • Well tested: Comprehensive test coverage for all scenarios

Implementation Details
The fix uses integer division (//) which is safe because:

  • tpu7x-8: 8 devices // 2 chiplets = 4 chips ✓
  • tpu7x-4: 4 devices // 2 chiplets = 2 chips ✓
  • v5e/v6e: N devices // 1 chiplet = N chips ✓ (unchanged)

Logging is conditional - chiplet info only appears for tpu7x:
tpu7x-8 (after fix):
TPU info: ... | num_chips=4 | num_chiplets_per_chip=2 | num_cores_per_chip=2
v6e-8 (unchanged, no chiplet info to avoid confusion):
TPU info: ... | num_chips=8 | num_cores_per_chip=1

Test Coverage Added

  • test_get_num_chiplets_per_chip: Tests all TPU types including tpu7x-8, tpu7x-4, v6e-8, v5litepod, and edge cases (None, empty string)
  • test_get_num_chips_tpu7x_from_accel: Verifies tpu7x-8 with 8 /dev/accel* devices returns 4 chips
  • test_get_num_chips_tpu7x_4_from_accel: Verifies tpu7x-4 with 4 /dev/accel* devices returns 2 chips
  • test_get_num_chips_tpu7x_from_vfio: Verifies tpu7x-8 with /dev/vfio path returns 4 chips
  • test_get_num_chips_non_tpu7x_unchanged: Verifies v6e-8 still returns 8 chips (backward compatibility)

How to Test
pytest tests/test_tpu_info.py -v

Run specific new tests
pytest tests/test_tpu_info.py::test_get_num_chiplets_per_chip -v
pytest tests/test_tpu_info.py::test_get_num_chips_tpu7x_from_accel -v
pytest tests/test_tpu_info.py::test_get_num_chips_non_tpu7x_unchanged -v
All tests pass with the changes.

@burbajr burbajr requested a review from vipannalla as a code owner December 8, 2025 23:17
@burbajr burbajr force-pushed the fix/tpu7x-chip-counting branch from 68301ae to e92f0a2 Compare December 8, 2025 23:21
For TPU7x devices, each physical chip contains 2 chiplets that are
exposed to the host as separate devices. This was causing
get_num_chips() to report double the actual chip count.

Example: tpu7x-8 has 4 physical chips (8 chiplets total)
- Before: reported num_chips=8 (incorrect)
- After: reports num_chips=4 (correct)

Implementation:
- Modified get_num_chips() to detect tpu7x devices and divide the
  device count by 2 using integer division
- Added test_get_num_chips_tpu7x to verify correct chip counting

All non-tpu7x devices are unaffected by this change.

Signed-off-by: burbajr <joey.burba@gmail.com>
@burbajr burbajr force-pushed the fix/tpu7x-chip-counting branch from f70ba2e to a9d4a49 Compare December 9, 2025 22:40
# For tpu7x, each chip has 2 chiplets exposed as separate devices
tpu_type = get_tpu_type()
if tpu_type and "tpu7x" in tpu_type.lower():
return num_devices // 2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really a fan of hard coding the value "2" or having "tpu7x" specific logic.

Is there any logic within tpu_info that programatically fetches num_devices?

Or wouldn't a simple len(jax.devices()) return correct number of cores - regardless of having 1 chiplet or 1 chiplets?

Copy link
Author

@burbajr burbajr Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, and thank you for taking a closer look at this.

The original logic (below) for determining the number of cores per TPU family remains correct for v7, which has 2 cores per chip, so no changes are needed there.

def get_num_cores_per_chip() -> int:
tpu_type = get_tpu_type()
if tpu_type.startswith(("v5litepod", "v6e")):
return 1
return 2

The issue is only with chip count inference. Prior to v7, each physical chip (regardless of core counts) exposed a single logical device, so using jax.devices() or the OS device list was a reliable way to infer the number of chips, since there was a 1:1 mapping between physical chips and logical devices. The original logic in tpu_info simply counts the number of host devices to report the number of chips.

With v7, each physical chip now exposes two logical devices, which means len(jax.devices()) and the OS device list report twice as many devices as there are actual chips, leading to the current reporting bug.

I was not able to identify a reliable programmatic way to retrieve the physical TPU topology on TPU/GKE nodes, so dividing the device count by two is the most accurate approach I have found for correcting the reported chip count.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, curious if anything is still outstanding on this PR or if you have any questions. Thank you!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel this change feels unnecessary complicated? I think just changing our wording from "num chips" to "num cores" would be sufficient.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, thanks for taking another look at this. Let me try to explain the motivation a bit more clearly.

The existing / original logic already hard codes the number of cores per chip based on the TPU family prefix, and that part is correct and unchanged. v5litepod and v6e have 1 core per chip, and other families have 2 cores per chip.

def get_num_cores_per_chip() -> int:
tpu_type = get_tpu_type()
if tpu_type.startswith(("v5litepod", "v6e")):
return 1
return 2

Where things get tricky is inferring the total number of chips or total number of cores from the number of JAX or OS visible devices. That mapping is no longer consistent across TPU families.

For example, using the same 2x2x1 topology:

v5p
4 chips, 2 cores per chip, 8 cores total
This exposes 4 JAX or OS devices

v7
4 chips, 2 cores per chip, 8 cores total
This exposes 8 JAX or OS devices

So simply counting len(jax.devices()) gives different results depending on the TPU generation, even though the physical topology and total core count are the same. Because of that, changing the field to represent cores without additional logic would still be inaccurate in some cases.

I do not have a strong opinion on the final approach. Whether we preserve chip semantics, rename fields, or restructure how this is reported, I mainly wanted to point out that the current logic cannot reliably calculate either total chips or total cores if it relies only on device count.

Happy to align on whatever direction makes the most sense here, and thanks for taking a look.

@burbajr
Copy link
Author

burbajr commented Dec 9, 2025

Updated implementation based on feedback to use a simpler approach:

  • Removed get_num_chiplets_per_chip() helper function
  • Simplified get_num_chips() with inline tpu7x check
  • Reverted logging changes to original simple format
  • Kept single focused test for tpu7x chip counting

"/dev/accel0", "/dev/accel1", "/dev/accel2", "/dev/accel3",
"/dev/accel4", "/dev/accel5", "/dev/accel6", "/dev/accel7"
])
def test_get_num_chips_tpu7x(mock_glob, mock_get_tpu_type):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of creating a separate unit test for tpu7x, can you use existing unit test but add parameter for tpu7x?

# For tpu7x, each chip has 2 chiplets exposed as separate devices
tpu_type = get_tpu_type()
if tpu_type and "tpu7x" in tpu_type.lower():
return num_devices // 2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel this change feels unnecessary complicated? I think just changing our wording from "num chips" to "num cores" would be sufficient.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants