diff --git a/lib/Service.py b/lib/Service.py
index 544f85b..9cc4026 100644
--- a/lib/Service.py
+++ b/lib/Service.py
@@ -7,7 +7,6 @@
 import json
 import logging
 import os
-from contextlib import contextmanager
 from copy import deepcopy
 from time import perf_counter
 from typing import TypedDict
@@ -33,39 +32,6 @@ class TranslateRequest(TypedDict):
 ctranslate2.set_random_seed(420)
 
 
-@contextmanager
-def translate_context(config: dict):
-    try:
-        tokenizer = SentencePieceProcessor()
-        tokenizer.Load(os.path.join(config["loader"]["model_path"], config["tokenizer_file"]))
-
-        translator = ctranslate2.Translator(
-            **{
-                # only NVIDIA GPUs are supported by CTranslate2 for now
-                "device": "cuda" if os.getenv("COMPUTE_DEVICE") == "CUDA" else "cpu",
-                **config["loader"],
-            }
-        )
-    except KeyError as e:
-        raise ServiceException(
-            "Incorrect config file, ensure all required keys are present from the default config"
-        ) from e
-    except Exception as e:
-        raise ServiceException("Error loading the translation model") from e
-
-    try:
-        start = perf_counter()
-        yield (tokenizer, translator)
-        elapsed = perf_counter() - start
-        logger.info(f"time taken: {elapsed:.2f}s")
-    except Exception as e:
-        raise ServiceException("Error translating the input text") from e
-    finally:
-        del tokenizer
-        # todo: offload to cpu?
-        del translator
-
-
 class Service:
     def __init__(self, config: dict):
         global logger
@@ -94,12 +60,34 @@ def load_config(self, config: dict):
 
         self.config = config_copy
 
+    def load_model(self):
+        try:
+            self.tokenizer = SentencePieceProcessor()
+            self.tokenizer.Load(os.path.join(self.config["loader"]["model_path"], self.config["tokenizer_file"]))
+
+            self.translator = ctranslate2.Translator(
+                **{
+                    "device": "cuda" if os.getenv("COMPUTE_DEVICE") == "CUDA" else "cpu",
+                    **self.config["loader"],
+                }
+            )
+        except KeyError as e:
+            raise ServiceException(
+                "Incorrect config file, ensure all required keys are present from the default config"
+            ) from e
+        except Exception as e:
+            raise ServiceException("Error loading the translation model") from e
+
     def translate(self, data: TranslateRequest) -> str:
         logger.debug(f"translating text to: {data['target_language']}")
 
-        with translate_context(self.config) as (tokenizer, translator):
-            input_tokens = tokenizer.Encode(f"<2{data['target_language']}> {clean_text(data['input'])}", out_type=str)
-            results = translator.translate_batch(
+        try:
+            start = perf_counter()
+            input_tokens = self.tokenizer.Encode(
+                f"<2{data['target_language']}> {clean_text(data['input'])}",
+                out_type=str,
+            )
+            results = self.translator.translate_batch(
                 [input_tokens],
                 batch_type="tokens",
                 **self.config["inference"],
@@ -109,7 +97,11 @@ def translate(self, data: TranslateRequest) -> str:
                 raise ServiceException("Empty result returned from translator")
 
             # todo: handle multiple hypotheses
-            translation = tokenizer.Decode(results[0].hypotheses[0])
+            translation = self.tokenizer.Decode(results[0].hypotheses[0])
+            elapsed = perf_counter() - start
+            logger.info(f"time taken: {elapsed:.2f}s")
+        except Exception as e:
+            raise ServiceException("Error translating the input text") from e
 
         logger.debug(f"Translated string: {translation}")
         return translation
diff --git a/lib/main.py b/lib/main.py
index c7be29d..c60024a 100644
--- a/lib/main.py
+++ b/lib/main.py
@@ -122,6 +122,7 @@ async def _(request: Request, exc: Exception):
 def task_fetch_thread(service: Service):
     global app_enabled
 
+    service.load_model()
     nc = NextcloudApp()
     while True:
         if not app_enabled.is_set():
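
Net effect of the patch: the SentencePiece tokenizer and the CTranslate2 translator are loaded once at startup (from `task_fetch_thread`) and kept on the `Service` instance, instead of being constructed and torn down by a context manager on every `translate()` call. A minimal sketch of the resulting call pattern, assuming a config with the `loader`, `tokenizer_file`, and `inference` keys used in the diff; the import path, concrete paths, filenames, and inference options below are hypothetical placeholders, not the app's actual defaults:

```python
# Hypothetical usage sketch -- values are illustrative placeholders.
from lib.Service import Service

config = {
    "loader": {"model_path": "models/translate"},  # hypothetical model directory
    "tokenizer_file": "spiece.model",              # hypothetical tokenizer filename
    "inference": {"max_batch_size": 8192},         # hypothetical CTranslate2 options
}

service = Service(config)
service.load_model()  # one-time load; previously this happened per request

# Only the two TranslateRequest fields the diff references are shown here.
request = {"input": "Hello, world!", "target_language": "de"}
print(service.translate(request))
```

One trade-off worth noting: the removed `translate_context` deleted the tokenizer and translator after every request (per its `# todo: offload to cpu?` comment, partly to release memory between calls), whereas with `load_model` the model stays resident for the lifetime of the worker thread in exchange for not paying the load cost on each translation.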