Skip to content

Commit 3ea092b

Browse files
committed
Merge branch 'main' into 94-errors-and-logging
2 parents 7966891 + 6b64253 commit 3ea092b

2 files changed

Lines changed: 26 additions & 3 deletions

File tree

src/classifai/vectorisers/huggingface.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,26 @@ class HuggingFaceVectoriser(VectoriserBase):
1414
tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the specified model.
1515
model (transformers.PreTrainedModel): The Huggingface model instance.
1616
device (torch.device): The device (CPU or GPU) on which the model is loaded.
17+
tokenizer_kwargs (dict): Additional keyword arguments passed to the tokenizer.
18+
model_kwargs (dict): Additional keyword arguments passed to the model.
1719
"""
1820

19-
def __init__(self, model_name, device=None, model_revision="main"):
21+
def __init__(
22+
self,
23+
model_name,
24+
device=None,
25+
model_revision="main",
26+
tokenizer_kwargs: dict | None = None,
27+
model_kwargs: dict | None = None,
28+
):
2029
"""Initializes the HuggingfaceVectoriser with the specified model name and device.
2130
2231
Args:
2332
model_name (str): The name of the Huggingface model to use.
2433
device (torch.device, optional): The device to use for computation. Defaults to GPU if available, otherwise CPU.
2534
model_revision (str, optional): The specific model revision to use. Defaults to "main".
35+
tokenizer_kwargs (dict, optional): Additional keyword arguments to pass to the tokenizer. Defaults to None.
36+
model_kwargs (dict, optional): Additional keyword arguments to pass to the model. Defaults to None.
2637
2738
Raises:
2839
ExternalServiceError: If the model or tokenizer cannot be loaded.
@@ -33,9 +44,17 @@ def __init__(self, model_name, device=None, model_revision="main"):
3344
from transformers import AutoModel, AutoTokenizer # type: ignore
3445

3546
self.model_name = model_name
47+
48+
tokenizer_kwargs = dict(tokenizer_kwargs or {})
49+
model_kwargs = dict(model_kwargs or {})
50+
51+
# Ensure consistent behavior unless user overrides it
52+
tokenizer_kwargs.setdefault("trust_remote_code", False)
53+
model_kwargs.setdefault("trust_remote_code", False)
54+
3655
try:
37-
self.tokenizer = AutoTokenizer.from_pretrained(model_name, revision=model_revision) # nosec: B615
38-
self.model = AutoModel.from_pretrained(model_name, revision=model_revision) # nosec: B615
56+
self.tokenizer = AutoTokenizer.from_pretrained(model_name, revision=model_revision, **tokenizer_kwargs) # nosec: B615
57+
self.model = AutoModel.from_pretrained(model_name, revision=model_revision, **model_kwargs) # nosec: B615
3958
except Exception as e:
4059
raise ExternalServiceError(
4160
"Failed to load HuggingFace model/tokenizer.",

uv.lock

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)