diff --git a/sagemaker-core/pyproject.toml b/sagemaker-core/pyproject.toml index 9c76166594..5c11c7d74d 100644 --- a/sagemaker-core/pyproject.toml +++ b/sagemaker-core/pyproject.toml @@ -32,8 +32,6 @@ dependencies = [ "smdebug_rulesconfig>=1.0.1", "schema>=0.7.5", "omegaconf>=2.1.0", - "torch>=1.9.0", - "scipy>=1.5.0", # Remote function dependencies "cloudpickle>=2.0.0", "paramiko>=2.11.0", @@ -52,6 +50,13 @@ classifiers = [ ] [project.optional-dependencies] +torch = [ + "torch>=1.9.0", + "scipy>=1.5.0", +] +all = [ + "sagemaker-core[torch]", +] codegen = [ "black>=24.3.0, <25.0.0", "pandas>=2.0.0, <3.0.0", diff --git a/sagemaker-core/src/sagemaker/core/deserializers/base.py b/sagemaker-core/src/sagemaker/core/deserializers/base.py index 4faae7db74..a053808a2c 100644 --- a/sagemaker-core/src/sagemaker/core/deserializers/base.py +++ b/sagemaker-core/src/sagemaker/core/deserializers/base.py @@ -366,7 +366,10 @@ def __init__(self, accept="tensor/pt"): self.convert_npy_to_tensor = from_numpy except ImportError: - raise Exception("Unable to import pytorch.") + raise ImportError( + "torch is required for TorchTensorDeserializer. " + "Install it with: pip install 'sagemaker-core[torch]'" + ) def deserialize(self, stream, content_type="tensor/pt"): """Deserialize streamed data to TorchTensor diff --git a/sagemaker-core/src/sagemaker/core/serializers/base.py b/sagemaker-core/src/sagemaker/core/serializers/base.py index a4ecf7c1dc..14d4867067 100644 --- a/sagemaker-core/src/sagemaker/core/serializers/base.py +++ b/sagemaker-core/src/sagemaker/core/serializers/base.py @@ -443,7 +443,13 @@ class TorchTensorSerializer(SimpleBaseSerializer): def __init__(self, content_type="tensor/pt"): super(TorchTensorSerializer, self).__init__(content_type=content_type) - from torch import Tensor + try: + from torch import Tensor + except ImportError: + raise ImportError( + "torch is required for TorchTensorSerializer. " + "Install it with: pip install 'sagemaker-core[torch]'" + ) self.torch_tensor = Tensor self.numpy_serializer = NumpySerializer() diff --git a/sagemaker-core/tests/unit/serializers/test_torch_optional.py b/sagemaker-core/tests/unit/serializers/test_torch_optional.py new file mode 100644 index 0000000000..64e2b4a15b --- /dev/null +++ b/sagemaker-core/tests/unit/serializers/test_torch_optional.py @@ -0,0 +1,108 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import sys +from unittest import mock + +import pytest + + +def test_torch_tensor_serializer_raises_import_error_when_torch_missing(): + """Verify TorchTensorSerializer raises ImportError with helpful message when torch is missing.""" + with mock.patch.dict(sys.modules, {"torch": None}): + # Need to reload the module to pick up the mocked import + from sagemaker.core.serializers.base import TorchTensorSerializer + + with pytest.raises(ImportError, match="pip install"): + TorchTensorSerializer() + + +def test_torch_tensor_deserializer_raises_import_error_when_torch_missing(): + """Verify TorchTensorDeserializer raises ImportError with helpful message when torch is missing.""" + with mock.patch.dict(sys.modules, {"torch": None}): + from sagemaker.core.deserializers.base import TorchTensorDeserializer + + with pytest.raises(ImportError, match="pip install"): + TorchTensorDeserializer() + + +def test_torch_tensor_serializer_works_when_torch_available(): + """Verify TorchTensorSerializer can be instantiated when torch is available.""" + torch = pytest.importorskip("torch") + from sagemaker.core.serializers.base import TorchTensorSerializer + + serializer = TorchTensorSerializer() + assert serializer.CONTENT_TYPE == "tensor/pt" + + # Test serialization of a simple tensor + tensor = torch.tensor([1.0, 2.0, 3.0]) + result = serializer.serialize(tensor) + assert result is not None + + +def test_torch_tensor_deserializer_works_when_torch_available(): + """Verify TorchTensorDeserializer can be instantiated when torch is available.""" + pytest.importorskip("torch") + from sagemaker.core.deserializers.base import TorchTensorDeserializer + + deserializer = TorchTensorDeserializer() + assert deserializer.ACCEPT == ("tensor/pt",) + + +def test_base_serializers_importable_without_torch(): + """Verify non-torch serializers can be imported and used without torch.""" + from sagemaker.core.serializers.base import ( + CSVSerializer, + NumpySerializer, + JSONSerializer, + IdentitySerializer, + JSONLinesSerializer, + LibSVMSerializer, + DataSerializer, + StringSerializer, + ) + + # Verify they can be instantiated + assert CSVSerializer() is not None + assert NumpySerializer() is not None + assert JSONSerializer() is not None + assert IdentitySerializer() is not None + assert JSONLinesSerializer() is not None + assert LibSVMSerializer() is not None + assert DataSerializer() is not None + assert StringSerializer() is not None + + +def test_base_deserializers_importable_without_torch(): + """Verify non-torch deserializers can be imported and used without torch.""" + from sagemaker.core.deserializers.base import ( + StringDeserializer, + BytesDeserializer, + CSVDeserializer, + StreamDeserializer, + NumpyDeserializer, + JSONDeserializer, + PandasDeserializer, + JSONLinesDeserializer, + ) + + # Verify they can be instantiated + assert StringDeserializer() is not None + assert BytesDeserializer() is not None + assert CSVDeserializer() is not None + assert StreamDeserializer() is not None + assert NumpyDeserializer() is not None + assert JSONDeserializer() is not None + assert PandasDeserializer() is not None + assert JSONLinesDeserializer() is not None diff --git a/sagemaker-core/tox.ini b/sagemaker-core/tox.ini index 0e31a74d80..ecf996e8a9 100644 --- a/sagemaker-core/tox.ini +++ b/sagemaker-core/tox.ini @@ -94,7 +94,7 @@ commands = pytest {posargs} deps = -r ../requirements/extras/test_requirements.txt - ../sagemaker-core + ../sagemaker-core[torch] .[test] mock depends =