I find myself wanting to find out programmatically which "highest precision float type" a particular library supports on a particular device. Concretely, this came up with PyTorch and the MPS device (PyTorch's name for the GPU in Apple M1, and possibly M2, chips). PyTorch does not support float64 on MPS, which is how I ended up wanting a way to find out which float type has the highest available precision.
```python
import torch
import array_api_compat

x = torch.tensor([1, 2, 3], device="mps", dtype=torch.float32)
xp = array_api_compat.get_namespace(x)
# side quest: is there a better way to get the torch namespace?
x = xp.asarray([1, 2, 3], device="mps", dtype=xp.float32)

# Maybe `can_cast` is the right tool?
xp.can_cast(xp.float32, xp.float64)  # -> True
xp.can_cast(x, xp.float64)  # -> True
```

Presumably both calls to `can_cast` return `True` because PyTorch supports float64 in general, and the implementation of `can_cast` does not inspect the device of `x`. So, at least as it is currently implemented, I think `can_cast` is not the right tool for checking whether float64 is available (and falling back to float32 if it is not). The alternative would be writing my own `highest_precision_float()`.
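A minimal sketch of such a hand-rolled `highest_precision_float()`, assuming that allocating an unsupported dtype on a device fails eagerly with an exception (PyTorch is expected to raise for float64 on MPS, but that is an assumption here, not something array-api-compat guarantees). The function name and its probing strategy are hypothetical; it only relies on `xp.zeros(shape, dtype=..., device=...)` from the array API standard.

```python
def highest_precision_float(xp, device=None):
    """Return the widest real floating-point dtype ``xp`` can allocate on ``device``.

    Hypothetical helper, not part of array_api_compat or the array API standard.
    It probes candidate dtypes by allocating a one-element array and assumes
    that an unsupported dtype/device combination raises an exception.
    """
    # Candidates ordered from highest to lowest precision.
    for name in ("float64", "float32"):
        dtype = getattr(xp, name, None)
        if dtype is None:
            continue  # namespace does not expose this dtype at all
        try:
            # Assumption: unsupported combinations fail here rather than later.
            xp.zeros(1, dtype=dtype, device=device)
        except (TypeError, RuntimeError, ValueError):
            continue  # dtype exists but is not usable on this device
        return dtype
    raise ValueError(f"no supported floating-point dtype found on {device!r}")
```

With the namespace from the snippet above, the call would look like `highest_precision_float(xp, device=x.device)`; the expectation (untested here) is that it returns `float32` for an MPS tensor and `float64` on the CPU. Probing by allocation is crude, but unlike `can_cast` it actually exercises the device.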