diff --git a/lib/Service.py b/lib/Service.py index 60b792f..544f85b 100644 --- a/lib/Service.py +++ b/lib/Service.py @@ -17,8 +17,6 @@ from sentencepiece import SentencePieceProcessor from util import clean_text -GPU_ACCELERATED = os.getenv("COMPUTE_DEVICE", "CPU") != "CPU" - logger = logging.getLogger(os.environ["APP_ID"] + __name__) class ServiceException(Exception): @@ -43,7 +41,8 @@ def translate_context(config: dict): translator = ctranslate2.Translator( **{ - "device": "cuda" if GPU_ACCELERATED else "cpu", + # only NVIDIA GPUs are supported by CTranslate2 for now + "device": "cuda" if os.getenv("COMPUTE_DEVICE") == "CUDA" else "cpu", **config["loader"], } )