We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 99e085b commit 14e95a9Copy full SHA for 14e95a9
1 file changed
array_api_tests/__init__.py
@@ -9,6 +9,7 @@
9
10
__all__ = ["xp", "api_version", "xps"]
11
12
+"""
13
14
# You can comment the following out and instead import the specific array module
15
# you want to test, e.g. `import array_api_strict as xp`.
@@ -40,6 +41,12 @@
40
41
"ARRAY_API_TESTS_MODULE environment variable."
42
)
43
44
45
+
46
+import array_api_compat.torch as xp
47
+xp_name = xp.__name__
48
49
+xp.set_default_device("mps")
50
51
# If xp.bool is not available, like in some versions of NumPy and CuPy, try
52
# patching in xp.bool_.
0 commit comments