Skip to content

Commit 0639a12

Browse files
authored
Add CUDA 12.9 compatibility note for DGX Spark/GB10 to 0.9.0 release … (#251)
* Add CUDA 12.9 compatibility note for DGX Spark/GB10 to 0.9.0 release notes
1 parent c294233 commit 0639a12

2 files changed

Lines changed: 4 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
- [Torch] Removed deprecated primitive classes: `TensorProduct`, `EquivariantTensorProduct`, `SymmetricTensorProduct`, and `IWeightedSymmetricTensorProduct`. Use `cuet.SegmentedPolynomial` with `method='uniform_1d'` instead, or the high-level APIs (`cuet.ChannelWiseTensorProduct`, `cuet.FullyConnectedTensorProduct`, `cuet.SymmetricContraction`). Attempting to import these classes will raise an `ImportError` with migration instructions.
1414
- [Torch] Removed deprecated low-level wrapper classes: `TensorProductUniform1d`, `TensorProductUniform4x1d`, `TensorProductUniform3x1dIndexed`, `TensorProductUniform4x1dIndexed`, and `SymmetricTensorContraction` from `cuequivariance_ops_torch`. Use `torch.ops.cuequivariance.uniform_1d` or `cuet.SegmentedPolynomial` instead.
1515

16+
### Notes
17+
- [JAX] DGX Spark/GB10 (sm_121) with CUDA 12.9: This release uses PTX 87, which works correctly for most architectures but is not compatible with DGX Spark/GB10 on CUDA 12.9. To enable DGX Spark/GB10 support with CUDA 12.9, refer to [#250](https://github.com/NVIDIA/cuEquivariance/pull/250) for a simple frontend integration tweak that restricts PTX 88 to sm_121 only. This fix will be merged after the 0.9.0 release.
18+
1619
### Added
1720
- [Torch/JAX] New environment variable `CUEQUIVARIANCE_OPS_NVRTC_CACHE_DIR` allows setting a directory for caching compiled kernels, improving JIT compilation time for uniform_1d kernels.
1821

cuequivariance_jax/cuequivariance_jax/triangle/triton_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def _get_max_ptx_version():
108108
# Map CUDA version to PTX version
109109
if major == 12:
110110
if minor >= 9:
111-
version = 87 # 88 breaks some triton tests
111+
version = 87 # 88 breaks some triton tests
112112
elif minor >= 8:
113113
version = 87
114114
elif minor >= 5:

0 commit comments

Comments
 (0)