diff --git a/esm/utils/misc.py b/esm/utils/misc.py index f0d7a602..8ff7e24e 100644 --- a/esm/utils/misc.py +++ b/esm/utils/misc.py @@ -1,5 +1,6 @@ import os from collections import defaultdict +from contextlib import nullcontext from io import BytesIO from typing import Any, ContextManager, Sequence, TypeVar from warnings import warn @@ -224,6 +225,9 @@ def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast """ if device_type == "cpu": return torch.amp.autocast(device_type, enabled=False) # type: ignore + elif device_type == "mps": + # For MPS, just return a no-op context manager (nullcontext) since MPS does not support autocast. + return nullcontext() elif device_type == "cuda": return torch.amp.autocast(device_type, dtype=torch.float32) # type: ignore else: