Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 30 additions & 38 deletions lib/Service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import json
import logging
import os
from contextlib import contextmanager
from copy import deepcopy
from time import perf_counter
from typing import TypedDict
Expand All @@ -33,39 +32,6 @@ class TranslateRequest(TypedDict):
# Pin CTranslate2's RNG to a fixed seed so any sampling-based decoding is
# reproducible across runs.
ctranslate2.set_random_seed(420)


@contextmanager
def translate_context(config: dict):
    """Context manager that loads the tokenizer and translator, yields them,
    and tears them down afterwards.

    Yields a ``(tokenizer, translator)`` pair built from ``config``. Timing of
    the managed block is logged at INFO level. Any failure while loading maps
    to a ServiceException describing a config or model-load problem; any
    failure inside the managed block maps to a translation ServiceException.

    Raises:
        ServiceException: on missing config keys, model load failure, or an
            error raised while translating inside the ``with`` block.
    """
    try:
        tokenizer = SentencePieceProcessor()
        # Tokenizer model lives alongside the translation model under
        # config["loader"]["model_path"].
        tokenizer.Load(os.path.join(config["loader"]["model_path"], config["tokenizer_file"]))

        translator = ctranslate2.Translator(
            **{
                # only NVIDIA GPUs are supported by CTranslate2 for now
                "device": "cuda" if os.getenv("COMPUTE_DEVICE") == "CUDA" else "cpu",
                # loader config is splatted last, so an explicit "device" key
                # there overrides the environment-derived default above
                **config["loader"],
            }
        )
    except KeyError as e:
        raise ServiceException(
            "Incorrect config file, ensure all required keys are present from the default config"
        ) from e
    except Exception as e:
        raise ServiceException("Error loading the translation model") from e

    try:
        start = perf_counter()
        yield (tokenizer, translator)
        # elapsed covers the whole managed block, not just a single call
        elapsed = perf_counter() - start
        logger.info(f"time taken: {elapsed:.2f}s")
    except Exception as e:
        raise ServiceException("Error translating the input text") from e
    finally:
        # Drop references so both models are freed after every use.
        del tokenizer
        # todo: offload to cpu?
        del translator


class Service:
def __init__(self, config: dict):
global logger
Expand Down Expand Up @@ -94,12 +60,34 @@ def load_config(self, config: dict):

self.config = config_copy

def load_model(self):
    """Initialise ``self.tokenizer`` and ``self.translator`` from ``self.config``.

    Loads the SentencePiece tokenizer model from the loader's ``model_path``
    and constructs a CTranslate2 translator on CPU, or on CUDA when the
    ``COMPUTE_DEVICE`` environment variable is set to ``"CUDA"``.

    Raises:
        ServiceException: when a required config key is missing, or when
            loading the tokenizer/translator fails for any other reason.
    """
    try:
        loader_cfg = self.config["loader"]
        tokenizer_path = os.path.join(loader_cfg["model_path"], self.config["tokenizer_file"])

        self.tokenizer = SentencePieceProcessor()
        self.tokenizer.Load(tokenizer_path)

        # Environment-derived default; an explicit "device" entry in the
        # loader config takes precedence because it is splatted last.
        device = "cuda" if os.getenv("COMPUTE_DEVICE") == "CUDA" else "cpu"
        translator_kwargs = {"device": device}
        translator_kwargs.update(loader_cfg)
        self.translator = ctranslate2.Translator(**translator_kwargs)
    except KeyError as e:
        raise ServiceException(
            "Incorrect config file, ensure all required keys are present from the default config"
        ) from e
    except Exception as e:
        raise ServiceException("Error loading the translation model") from e

def translate(self, data: TranslateRequest) -> str:
logger.debug(f"translating text to: {data['target_language']}")

with translate_context(self.config) as (tokenizer, translator):
input_tokens = tokenizer.Encode(f"<2{data['target_language']}> {clean_text(data['input'])}", out_type=str)
results = translator.translate_batch(
try:
start = perf_counter()
input_tokens = self.tokenizer.Encode(
f"<2{data['target_language']}> {clean_text(data['input'])}",
out_type=str,
)
results = self.translator.translate_batch(
[input_tokens],
batch_type="tokens",
**self.config["inference"],
Expand All @@ -109,7 +97,11 @@ def translate(self, data: TranslateRequest) -> str:
raise ServiceException("Empty result returned from translator")

# todo: handle multiple hypotheses
translation = tokenizer.Decode(results[0].hypotheses[0])
translation = self.tokenizer.Decode(results[0].hypotheses[0])
elapsed = perf_counter() - start
logger.info(f"time taken: {elapsed:.2f}s")
except Exception as e:
raise ServiceException("Error translating the input text") from e

logger.debug(f"Translated string: {translation}")
return translation
1 change: 1 addition & 0 deletions lib/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ async def _(request: Request, exc: Exception):
def task_fetch_thread(service: Service):
global app_enabled

service.load_model()
nc = NextcloudApp()
while True:
if not app_enabled.is_set():
Expand Down