-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtranslation_model.py
More file actions
72 lines (63 loc) · 2.9 KB
/
translation_model.py
File metadata and controls
72 lines (63 loc) · 2.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import numpy as np
from indicTrans2.inference.engine import Model, iso_to_flores
import nltk
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Folder that holds one checkpoint sub-folder per translation direction.
CHECKPOINTS_ROOT_DIR = "indicTrans2/checkpoints"
# Every key of the ISO->FLORES code map (presumably ISO language codes; may
# also contain the pivot language -- confirm against indicTrans2.inference).
INDIC_LANGUAGES = {*iso_to_flores}
# The only checkpoint sub-folder names that are accepted.
ALLOWED_DIRECTION_STRINGS = {"en-indic", "indic-en", "indic-indic"}
# Bridge language for Indic->Indic translation when pivoting is in effect.
DEFAULT_PIVOT_LANG = "en"
# When True, a loaded "indic-indic" model is discarded in favour of pivoting.
FORCE_PIVOTING = False
class LocalTranslationModel:
    """Local IndicTrans2 translator backed by CTranslate2 checkpoints.

    Each sub-folder of the checkpoints root directory must be named after a
    translation direction ("en-indic", "indic-en" or "indic-indic") and hold a
    ``ct2_int8_model`` directory; one ``Model`` is loaded per sub-folder.

    Pivoting applies when both English-facing models are loaded and either
    there is no "indic-indic" model or ``FORCE_PIVOTING`` is set. In that case
    Indic->Indic requests are translated in two hops through
    ``DEFAULT_PIVOT_LANG``.
    """

    def __init__(self, device="cpu", checkpoints_root_dir=None):
        """Load every checkpoint found under the checkpoints root directory.

        Args:
            device: Device string passed to each ``Model`` (for example
                "cpu" or "cuda"). Defaults to "cpu", the previous hard-coded
                value.
            checkpoints_root_dir: Directory containing one sub-folder per
                direction. Defaults to ``CHECKPOINTS_ROOT_DIR``.

        Raises:
            RuntimeError: If the root directory contains no sub-folders.
            ValueError: If a sub-folder name is not an allowed direction.
        """
        root_dir = (
            CHECKPOINTS_ROOT_DIR if checkpoints_root_dir is None else checkpoints_root_dir
        )
        # Close the scandir iterator deterministically instead of via GC.
        with os.scandir(root_dir) as entries:
            checkpoint_folders = [entry.path for entry in entries if entry.is_dir()]
        if not checkpoint_folders:
            raise RuntimeError(f"No checkpoint folders in: {root_dir}")

        self.models = {}
        for checkpoint_folder in checkpoint_folders:
            direction_string = os.path.basename(checkpoint_folder)
            if direction_string not in ALLOWED_DIRECTION_STRINGS:
                raise ValueError(f"Invalid checkpoint folder name: {direction_string}")
            self.models[direction_string] = Model(
                os.path.join(checkpoint_folder, "ct2_int8_model"),
                device=device,
                input_lang_code_format="iso",
                model_type="ctranslate2",
            )

        # Pivoting: with both English-facing models loaded, an Indic->Indic
        # request can be served as source -> pivot -> target (see translate).
        self.pivot_lang = None
        if "en-indic" in self.models and "indic-en" in self.models:
            if "indic-indic" not in self.models:
                self.pivot_lang = DEFAULT_PIVOT_LANG
            elif FORCE_PIVOTING:
                del self.models["indic-indic"]
                self.pivot_lang = DEFAULT_PIVOT_LANG

    def get_direction_string(self, input_language_id, output_language_id):
        """Map a (source, target) language pair to a checkpoint direction.

        Returns:
            "en-indic", "indic-en" or "indic-indic", or ``None`` when the pair
            matches none of them.
        """
        # Never classify the pivot language as "Indic", even if it is a key
        # of iso_to_flores; otherwise en->en would be routed to en-indic.
        src_is_pivot = input_language_id == DEFAULT_PIVOT_LANG
        tgt_is_pivot = output_language_id == DEFAULT_PIVOT_LANG
        src_is_indic = not src_is_pivot and input_language_id in INDIC_LANGUAGES
        tgt_is_indic = not tgt_is_pivot and output_language_id in INDIC_LANGUAGES

        if src_is_pivot and tgt_is_indic:
            return "en-indic"
        if src_is_indic and tgt_is_pivot:
            return "indic-en"
        if src_is_indic and tgt_is_indic:
            return "indic-indic"
        return None

    def _translate_with(self, direction_string, input_texts, input_language_id, output_language_id):
        """Run the model for ``direction_string`` over ``input_texts``."""
        model = self.models[direction_string]
        return model.paragraphs_batch_translate__multilingual(
            [[text, input_language_id, output_language_id] for text in input_texts]
        )

    def translate(self, input_texts, input_language_id, output_language_id):
        """Translate ``input_texts`` from one language to another.

        A direct model is used when one is loaded for the direction. An
        Indic->Indic request without a direct model is pivoted through
        ``self.pivot_lang`` when pivoting is enabled.

        Args:
            input_texts: Iterable of texts to translate.
            input_language_id: Source language code.
            output_language_id: Target language code.

        Returns:
            Whatever ``Model.paragraphs_batch_translate__multilingual``
            returns. For the pivot path this presumably means one translated
            string per input, in input order -- confirm against
            indicTrans2.inference.engine.

        Raises:
            RuntimeError: If the language pair cannot be served.
        """
        direction_string = self.get_direction_string(input_language_id, output_language_id)
        if direction_string in self.models:
            return self._translate_with(
                direction_string, input_texts, input_language_id, output_language_id
            )

        # pivot_lang is only set when both "indic-en" and "en-indic" are loaded.
        if direction_string == "indic-indic" and self.pivot_lang is not None:
            pivot = self.pivot_lang
            intermediate = self._translate_with("indic-en", input_texts, input_language_id, pivot)
            return self._translate_with("en-indic", intermediate, pivot, output_language_id)

        raise RuntimeError(f"Language pair not supported: {input_language_id}-{output_language_id}")
# Ensure the NLTK "punkt_tab" data is installed before the model is built
# (presumably needed by the engine's sentence splitting -- confirm). Only go
# to the network when the resource is actually missing, so that importing
# this module works offline once the data is cached.
try:
    nltk.data.find("tokenizers/punkt_tab")
except LookupError:
    nltk.download("punkt_tab")

# Initialize the translation model globally (loads all checkpoints at import).
translation_model = LocalTranslationModel()