Skip to content

Commit 14e95a9

Browse files
committed
REVERT: hardcode torch.mps
1 parent 99e085b commit 14e95a9

1 file changed

Lines changed: 7 additions & 0 deletions

File tree

array_api_tests/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
__all__ = ["xp", "api_version", "xps"]
1111

12+
"""
1213
1314
# You can comment the following out and instead import the specific array module
1415
# you want to test, e.g. `import array_api_strict as xp`.
@@ -40,6 +41,12 @@
4041
"ARRAY_API_TESTS_MODULE environment variable."
4142
)
4243
44+
"""
45+
46+
import array_api_compat.torch as xp
47+
xp_name = xp.__name__
48+
49+
xp.set_default_device("mps")
4350

4451
# If xp.bool is not available, like in some versions of NumPy and CuPy, try
4552
# patching in xp.bool_.

0 commit comments

Comments
 (0)